import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import pairwise_kernels
### modified coupling layer class
class CouplingLayer(nn.Module):
def __init__(self, input_size, hidden_size):
super(CouplingLayer, self).__init__()
# Neural networks for the first half of the dimensions
self.fc1 = nn.Linear(input_size // 2, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
# Translation coefficient
self.fc3 = nn.Linear(hidden_size, input_size // 2)
# Scaling coefficient
self.fc4 = nn.Linear(hidden_size, input_size // 2)
def forward(self, x):
# Split the input into two halves
x_a, x_b = x.chunk(2, dim=1)
# Apply neural network to calculate coefficients
h = F.relu(self.fc1(x_a))
h = F.relu(self.fc2(h))
translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))  ### tanh bounds the log-scale, keeping exp() stable
        scaling = torch.exp(scaling_before_exp)
# Apply the affine transformation
y_b = x_b * scaling + translation
# Concatenate the transformed halves
y = torch.cat([x_a, y_b], dim=1)
return y, scaling_before_exp
    def backward(self, y):  # inverse transform (unrelated to autograd's backward)
# Split the input into two halves
y_a, y_b = y.chunk(2, dim=1)
# Apply neural network to calculate coefficients (reverse)
h = F.relu(self.fc1(y_a))
h = F.relu(self.fc2(h))
translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))
        scaling = torch.exp(scaling_before_exp)
# Reverse the operations to reconstruct the original input
x_a = y_a
x_b = (y_b - translation) / scaling
# Concatenate the reconstructed halves
x = torch.cat([x_a, x_b], dim=1)
return x
class RealNVP(nn.Module):
def __init__(self, input_size, hidden_size, blocks):
super(RealNVP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.blocks = blocks
# List of coupling layers
self.coupling_layers = nn.ModuleList([
CouplingLayer(input_size, hidden_size) for _ in range(blocks)
])
        # Fixed random orthonormal matrices (note: a plain Python list, so these are not
        # registered as buffers and won't move with .to(device) or be saved in state_dict)
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]
# List to store scaling_before_exp for each block
self.scaling_before_exp_list = []
def _get_orthonormal_matrix(self, size):
# Function to generate a random orthonormal matrix
w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w, 'reduced')
return q
def forward_realnvp(self, x):
scaling_before_exp_list = []
for i in range(self.blocks):
# Apply random orthonormal matrix
x = torch.matmul(x, self.orthonormal_matrices[i])
# Apply coupling layer
x, scaling_before_exp = self.coupling_layers[i].forward(x)
scaling_before_exp_list.append(scaling_before_exp)
self.scaling_before_exp_list = scaling_before_exp_list
return x
def encode(self, x):
# Encoding is the forward pass through the RealNVP model
return self.forward_realnvp(x)
def decode(self, z):
# Reverse transformations for decoding
for i in reversed(range(self.blocks)):
# Apply coupling layer (reverse)
z = self.coupling_layers[i].backward(z)
# Apply random orthonormal matrix (reverse)
z = torch.matmul(z, self.orthonormal_matrices[i].t())
return z
def sample(self, num_samples=1000):
# Generate random samples from a standard normal distribution
with torch.no_grad():
z = torch.randn(num_samples, self.input_size)
# Apply the reverse transformations (decoder) to generate synthetic samples
synthetic_samples = self.decode(z)
return synthetic_samples
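Since each coupling layer is analytically invertible and the orthonormal matrices satisfy Q^T Q = I, decoding should undo encoding up to floating-point error. Below is a minimal round-trip check; it is a sketch we added for illustration (the small model and tensor names are hypothetical), not part of the original experiments.
### sanity check (hypothetical): decode(encode(x)) should recover x
check_model = RealNVP(input_size=2, hidden_size=16, blocks=3)
x_check = torch.randn(5, 2)
with torch.no_grad():
    x_rec = check_model.decode(check_model.encode(x_check))
print(torch.allclose(x_check, x_rec, atol=1e-5))  # expected: True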
### defining our loss function
def calculate_loss(transformed_x, scaling_before_exp_list, dataset_length):
"""
Calculate the loss for the RealNVP model.
Args:
- transformed_x (tensor): Transformed data produced by the RealNVP model.
- scaling_before_exp_list (list): List of scaling_before_exp values for each block.
- dataset_length (int): The length of the dataset.
Returns:
- loss (tensor): The calculated loss value.
"""
    # First term: 0.5 * ||z||^2, the negative log-density of a standard normal (up to an additive constant)
    first_term = 0.5 * torch.sum(transformed_x ** 2)
    # Second term: minus the log-determinant of the Jacobian, i.e. the sum of the log-scales over all blocks
    second_term = -torch.sum(torch.cat(scaling_before_exp_list))
# Calculate the total loss
loss = (first_term + second_term) / dataset_length
return loss
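By the change-of-variables formula, the negative log-likelihood per sample is 0.5*||z||^2 minus the sum of the log-scales (the log-determinant of the Jacobian), up to the constant (d/2)*log(2*pi), which is dropped here since it does not affect optimization. A small equivalence check, again a hypothetical sketch we added, with dataset_length set to the batch size:
### hypothetical check: calculate_loss equals the explicit per-sample NLL (constant dropped)
m_check = RealNVP(input_size=2, hidden_size=16, blocks=3)
z_check = m_check.encode(torch.randn(8, 2))
# per-sample log-determinant: sum of log-scales over blocks and dimensions
logdet = torch.stack(m_check.scaling_before_exp_list, dim=0).sum(dim=0).sum(dim=1)
nll = (0.5 * (z_check ** 2).sum(dim=1) - logdet).mean()
print(torch.allclose(nll, calculate_loss(z_check, m_check.scaling_before_exp_list, 8)))  # expected: True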
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
def train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
"""
Train the RealNVP model and evaluate on a validation dataset.
Args:
- model (RealNVP): The RealNVP model to be trained.
- train_loader (DataLoader): DataLoader for the training dataset.
- val_loader (DataLoader): DataLoader for the validation dataset.
- num_epochs (int): Number of training epochs.
- lr (float): Learning rate for the optimizer.
- print_after (int): Number of epochs after which to print the training and validation loss.
Returns:
- train_losses (list): List of training losses for each epoch.
- val_losses (list): List of validation losses for each epoch.
"""
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
train_losses = [] # List to store training losses
val_losses = [] # List to store validation losses
for epoch in range(num_epochs):
total_train_loss = 0.0
# Training phase
model.train() # Set the model to training mode
for data in train_loader:
            inputs = data
# Zero the gradients
optimizer.zero_grad()
# Forward pass (encoding)
encoded = model.encode(inputs)
# Loss calculation
            train_loss = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))  # note: len(train_loader) is the number of batches, used here as the loss normalizer
# Backward pass (gradient computation)
train_loss.backward()
### added recently: clip the gradients
clip_grad_norm_(model.parameters(), max_norm=1.0) # Adjust max_norm as needed
# Update weights
optimizer.step()
total_train_loss += train_loss.item()
# Average training loss for the epoch
average_train_loss = total_train_loss / len(train_loader)
# Validation phase
if val_loader is not None:
model.eval() # Set the model to evaluation mode
total_val_loss = 0.0
with torch.no_grad():
for val_data in val_loader:
val_inputs = val_data
# Forward pass (encoding) for validation
val_encoded = model.encode(val_inputs)
# Loss calculation for validation
val_loss = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
total_val_loss += val_loss.item()
# Average validation loss for the epoch
average_val_loss = total_val_loss / len(val_loader)
# Print training and validation losses together
if (epoch + 1) % print_after == 0:
print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss}, Validation Loss: {average_val_loss}")
# Append losses to the lists
train_losses.append(average_train_loss)
val_losses.append(average_val_loss)
# Set the model back to training mode
model.train()
print("Training complete")
return train_losses, val_losses
# function to plot training and validation losses
def plot_losses(epoch_train_losses, epoch_val_losses, want_log_scale=True):
"""
Plot training and validation losses over epochs on a log scale.
Args:
epoch_train_losses (list): List of training losses for each epoch.
epoch_val_losses (list): List of validation losses for each epoch.
"""
epochs = range(1, len(epoch_train_losses) + 1)
plt.plot(epochs, epoch_train_losses, label='Training Loss')
plt.plot(epochs, epoch_val_losses, label='Validation Loss')
    if want_log_scale:
        plt.yscale('log')  # set the y-axis to a logarithmic scale
        plt.title('Training and Validation Losses (Log Scale)', fontsize=10)
    else:
        plt.title('Training and Validation Losses', fontsize=10)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
def visualize_synthetic_data(original_data, synthetic_data):
"""
Scatter plot to visualize the original and synthetic data in 2D.
Args:
- original_data (torch.Tensor): Original data (2D).
- synthetic_data (torch.Tensor): Synthetic data (2D).
Returns:
- None: Displays the scatter plot.
"""
# Ensure both original and synthetic data are converted to numpy arrays
with torch.no_grad():
# Convert PyTorch tensors to numpy arrays
original_np = original_data.numpy()
synthetic_np = synthetic_data.numpy()
# Scatter plot of original and synthetic data
plt.scatter(original_np[:, 0], original_np[:, 1], label='Original', alpha=0.5)
plt.scatter(synthetic_np[:, 0], synthetic_np[:, 1], label='Synthetic', alpha=0.5)
# Add labels and title
plt.xlabel("dimension-1")
plt.ylabel("dimension-2")
plt.title('Original vs Synthetic Data')
# Add legend
plt.legend()
# Display the plot
#plt.show()
def plot_code_distribution(model, test_loader, num_samples=1000):
"""
Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.
Args:
- model (RealNVP): Trained RealNVP model.
- test_loader (DataLoader): DataLoader for the test dataset.
- num_samples (int): Number of samples to visualize.
Returns:
None (displays the plot).
"""
model.eval() # Set the model to evaluation mode
with torch.no_grad():
# Concatenate multiple batches to obtain more samples
test_samples = torch.cat([batch for batch in test_loader], dim=0)
# Assuming your model has an `encode` method
code_samples = model.encode(test_samples[:num_samples])
# Convert PyTorch tensor to numpy array
code_np = code_samples.numpy()
# Scatter plot of code distribution
plt.scatter(code_np[:, 0], code_np[:, 1], label='Code Distribution', alpha=0.5)
plt.xlabel("Code Dimension 1")
plt.ylabel("Code Dimension 2")
plt.title('Code Distribution')
plt.legend()
#plt.show()
dataset_sizes = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 5000]
# Generate datasets of varying sizes
train_datasets = {}
val_datasets = {}
datasets = {}
for size in dataset_sizes:
X, y = make_moons(n_samples=size, noise=0.1)
datasets[size] = {'X': X, 'y': y}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
train_datasets[size] = {'X': torch.FloatTensor(X_train), 'y': y_train}
val_datasets[size] = {'X': torch.FloatTensor(X_test), 'y': y_test}
# # Visualize the training datasets
# plt.figure(figsize=(12, 8))
# for i, size in enumerate(dataset_sizes, 1):
# plt.subplot(2, 2, i)
# plt.scatter(datasets[size]['X'][:, 0], datasets[size]['X'][:, 1], c=datasets[size]['y'])
# plt.title(f'Dataset Size: {size}')
# plt.show()
### creating the dataloader for the make moons dataset
from torch.utils.data import DataLoader, TensorDataset
### Trial run
import numpy as np
input_size = 2
hidden_size = 200  ### do I really need this to be this large?
blocks = 10  ### a larger number of coupling blocks pushes the code distribution closer to a Gaussian
print_after=1
#### data for the two-moons model
dataset_size=5000
batch_size=32
data_considered=train_datasets[dataset_size]['X']
print("shape of the data_considered"); print(data_considered.shape)
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader= torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=batch_size, shuffle=True)
####
### instantiate the model
model = RealNVP(input_size=2, hidden_size=hidden_size, blocks=blocks)
## train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=1)
# Earlier trial (lr=0.00005, num_epochs=20, dataset_size=5000, batch_size=64, blocks=10): the code distribution was more Gaussian and the generated data was comparatively better.
# plotting the loss
plot_losses(train_losses[3:], val_losses[3:], want_log_scale=False)
plt.show()
# Example usage:
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.show()
### plot the synthetic data and the original data
synthetic_data=model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[1000]['X'], synthetic_data)
plt.show()
shape of the data_considered torch.Size([3500, 2])
Epoch 1/10, Training Loss: -0.069945588813756, Validation Loss: -0.2947036237158674
Epoch 2/10, Training Loss: -0.13655666284559465, Validation Loss: -0.3800994145109298
Epoch 3/10, Training Loss: -0.15644234075126323, Validation Loss: -0.4005805108141392
Epoch 4/10, Training Loss: -0.15004104049876332, Validation Loss: -0.3940854034525283
Epoch 5/10, Training Loss: -0.17315212689678777, Validation Loss: -0.42223459038328615
Epoch 6/10, Training Loss: -0.17323222608220848, Validation Loss: -0.34632147388889434
Epoch 7/10, Training Loss: -0.17355674239383503, Validation Loss: -0.4644995800992276
Epoch 8/10, Training Loss: -0.18419327680021524, Validation Loss: -0.3325619826767039
Epoch 9/10, Training Loss: -0.18445073308592494, Validation Loss: -0.45367521745093325
Epoch 10/10, Training Loss: -0.18611254140057348, Validation Loss: -0.40497843128569583
Training complete
import numpy as np
from sklearn.metrics.pairwise import pairwise_kernels
import matplotlib.pyplot as plt
def compute_mmd(X, Y, kernel='rbf', gamma=None):
"""
Compute Maximum Mean Discrepancy (MMD) between two datasets.
Parameters:
- X, Y: Input datasets (numpy arrays).
- kernel: Kernel function to use ('linear', 'rbf', etc.).
- gamma: Kernel coefficient for 'rbf' kernel (if applicable).
Returns:
- mmd: Maximum Mean Discrepancy value.
"""
X = X.detach().numpy() if isinstance(X, torch.Tensor) else X
Y = Y.detach().numpy() if isinstance(Y, torch.Tensor) else Y
# Compute pairwise kernel matrices
K_xx = pairwise_kernels(X, X, metric=kernel, gamma=gamma)
K_yy = pairwise_kernels(Y, Y, metric=kernel, gamma=gamma)
K_xy = pairwise_kernels(X, Y, metric=kernel, gamma=gamma)
# Compute MMD
mmd = np.mean(K_xx) + np.mean(K_yy) - 2 * np.mean(K_xy)
return mmd
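As a quick illustration of how the score behaves (a hypothetical example we added, not one of the recorded runs): MMD should be near zero for two samples from the same distribution and clearly larger when the distributions differ.
### hypothetical usage example for compute_mmd
rng = np.random.default_rng(0)
a = rng.normal(size=(500, 2))
b = rng.normal(size=(500, 2))           # same distribution as a
c = rng.normal(loc=3.0, size=(500, 2))  # shifted distribution
print(compute_mmd(a, b))  # close to 0
print(compute_mmd(a, c))  # substantially larger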
### Input_size=2, hidden_size=200, lr=0.0001, num_epochs=10: Fixed
def train_and_plot_for_different_block_sizes(blocks_values, train_loader, val_loader):
results = []
for blocks in blocks_values:
print(f"\nTraining for blocks={blocks}")
# Instantiate the model
model = RealNVP(input_size=2, hidden_size=200, blocks=blocks)
# Train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=100)
# Plot code distribution
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.xlim(-3,3)
plt.ylim(-3,3)
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.title(f'Code Distribution (blocks={blocks})')
# Plot synthetic data
plt.subplot(1, 2, 2)
synthetic_data = model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[1000]['X'], synthetic_data)
plt.title(f'Synthetic Data (blocks={blocks})')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
# # Calculate MMD score
mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
print(f'MMD Score (blocks={blocks}): {mmd_value:.4f}')
results.append((blocks, mmd_value))
# # Plot MMD scores
plt.figure(figsize=(8, 5))
blocks, mmd_values = zip(*results)
plt.plot(blocks, mmd_values, marker='o')
plt.title('MMD Scores for Different Number of Blocks')
plt.yscale('log')
plt.xlabel('Blocks')
plt.ylabel('MMD Score')
plt.show()
######### for different coupling blocks
dataset_size=5000
print(f"For fixed dataset_size={dataset_size}, hidden_size=200, lr=0.0001, num_epochs=10")
batch_size=32
data_considered=train_datasets[dataset_size]['X']
print("shape of the data_considered"); print(data_considered.shape)
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader= torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=batch_size, shuffle=True)
####
blocks_values_to_try = [1,2,10,15]
train_and_plot_for_different_block_sizes(blocks_values_to_try, train_loader, val_loader)
For fixed dataset_size=5000, hidden_size=200, lr=0.0001, num_epochs=10
shape of the data_considered torch.Size([3500, 2])
Training for blocks=1 Training complete
MMD Score (blocks=1): 0.0859
Training for blocks=2 Training complete
MMD Score (blocks=2): 0.0092
Training for blocks=10 Training complete
MMD Score (blocks=10): 0.0040
Training for blocks=15 Training complete
MMD Score (blocks=15): 0.0008
Number of coupling blocks: the MMD score falls steadily as the number of coupling blocks increases (0.0859 for 1 block down to 0.0008 for 15 blocks), i.e. deeper flows model the two-moons data noticeably better.
def train_and_plot_for_different_dataset_sizes(dataset_sizes, train_loader, val_loader):
results = []
for dataset_size in dataset_sizes:
print(f"\nTraining for dataset_size={dataset_size}")
# Instantiate the model
model = RealNVP(input_size=2, hidden_size=200, blocks=10) # Fix other parameters
# Create data loader for the current dataset size
data_considered = train_datasets[dataset_size]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)
# Train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.0001, print_after=100)
# Plot code distribution
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.title(f'Code Distribution (dataset_size={dataset_size})')
# Plot synthetic data
plt.subplot(1, 2, 2)
synthetic_data = model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
plt.title(f'Synthetic Data (dataset_size={dataset_size})')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
# Calculate MMD score
mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
print(f'MMD Score (dataset_size={dataset_size}): {mmd_value:.4f}')
results.append((dataset_size, mmd_value))
# Plot MMD scores
plt.figure(figsize=(8, 5))
dataset_sizes, mmd_values = zip(*results)
plt.plot(dataset_sizes, mmd_values, marker='o')
plt.title('MMD Scores for Different Dataset Sizes')
plt.yscale('log')
plt.xlabel('Dataset Size')
plt.ylabel('MMD Score')
plt.show()
# Different dataset sizes to try
print(f"For fixed number of blocks=10, hidden_size=200, lr=0.0001, num_epochs=10")
dataset_sizes_to_try = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 5000]
train_and_plot_for_different_dataset_sizes(dataset_sizes_to_try, train_loader, val_loader)
For fixed number of blocks=10, hidden_size=200, lr=0.0001, num_epochs=10
Training for dataset_size=100 Training complete
MMD Score (dataset_size=100): 0.0202
Training for dataset_size=200 Training complete
MMD Score (dataset_size=200): 0.0222
Training for dataset_size=300 Training complete
MMD Score (dataset_size=300): 0.0111
Training for dataset_size=400 Training complete
MMD Score (dataset_size=400): 0.0030
Training for dataset_size=500 Training complete
MMD Score (dataset_size=500): 0.0083
Training for dataset_size=600 Training complete
MMD Score (dataset_size=600): 0.0037
Training for dataset_size=700 Training complete
MMD Score (dataset_size=700): 0.0042
Training for dataset_size=800 Training complete
MMD Score (dataset_size=800): 0.0046
Training for dataset_size=900 Training complete
MMD Score (dataset_size=900): 0.0033
Training for dataset_size=1000 Training complete
MMD Score (dataset_size=1000): 0.0052
Training for dataset_size=5000 Training complete
MMD Score (dataset_size=5000): 0.0015
We did a few trial runs of the cell above for datasets of different sizes. In general, the quality of the generated synthetic dataset improves as the dataset size grows.
def train_and_plot_for_different_learning_rates(learning_rates, dataset_size=1000, block_size=10):
results = []
for lr in learning_rates:
print(f"\nTraining for learning rate={lr}")
# Instantiate the model
model = RealNVP(input_size=2, hidden_size=200, blocks=block_size) # Fix other parameters
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_size]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)
# Train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=lr, print_after=1)
# Plot code distribution
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.title(f'Code Distribution (learning_rate={lr})')
# Plot synthetic data
plt.subplot(1, 2, 2)
synthetic_data = model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
plt.title(f'Synthetic Data (learning_rate={lr})')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
# Calculate MMD score
mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
print(f'MMD Score (learning_rate={lr}): {mmd_value:.4f}')
results.append((lr, mmd_value))
# Plot MMD scores
plt.figure(figsize=(8, 5))
learning_rates, mmd_values = zip(*results)
plt.plot(learning_rates, mmd_values, marker='o')
plt.title('MMD Scores for Different Learning Rates')
plt.xlabel('Learning Rate')
plt.ylabel('MMD Score')
plt.xscale('log') # Use a logarithmic scale for better visualization of different orders of magnitude
plt.yscale('log')
plt.show()
# Different learning rates to try
print("For fixed number of blocks=10, hidden_size=200, dataset_size=1000, num_epochs=10")
learning_rates_to_try = [0.01,0.005,0.0005,0.0001,0.000005]
train_and_plot_for_different_learning_rates(learning_rates_to_try, dataset_size=1000, block_size=10)
For fixed number of blocks=10, hidden_size=200, dataset_size=1000, num_epochs=10
Training for learning rate=0.01
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
in train_and_plot_for_different_learning_rates(learning_rates_to_try, dataset_size=1000, block_size=10)
---> data_considered = train_datasets[dataset_size]['X']
KeyError: 1000
We observed that lower learning rates suited the RealNVP (INN) model best; a learning rate of 1e-4 was the most suitable in this case.
def train_and_plot_for_different_epochs(epochs_list, dataset_size=1000, block_size=10, lr=0.0001):
results = []
for num_epochs in epochs_list:
print(f"\nTraining for num_epochs={num_epochs}")
# Instantiate the model
model = RealNVP(input_size=2, hidden_size=200, blocks=block_size) # Fix other parameters
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_size]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_datasets[dataset_size]['X'], batch_size=32, shuffle=True)
# Train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=num_epochs, lr=lr, print_after=2)
# Plot code distribution
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
plt.title(f'Code Distribution (num_epochs={num_epochs})')
# Plot synthetic data
plt.subplot(1, 2, 2)
synthetic_data = model.sample(num_samples=1000)
visualize_synthetic_data(train_datasets[dataset_size]['X'], synthetic_data)
plt.title(f'Synthetic Data (num_epochs={num_epochs})')
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
# Calculate MMD score
mmd_value = compute_mmd(val_loader.dataset, synthetic_data)
print(f'MMD Score (num_epochs={num_epochs}): {mmd_value:.4f}')
results.append((num_epochs, mmd_value))
# Plot MMD scores
plt.figure(figsize=(8, 5))
num_epochs_values, mmd_values = zip(*results)
plt.plot(num_epochs_values, mmd_values, marker='o')
plt.title('MMD Scores for Different Numbers of Epochs')
plt.xlabel('Number of Epochs')
plt.ylabel('MMD Score')
plt.show()
# Different numbers of epochs to try
print("For fixed number of blocks=10, hidden_size=200, dataset_size=5000, lr=0.0001")
epochs_to_try = [10,40,70,100]
train_and_plot_for_different_epochs(epochs_to_try, dataset_size=5000, block_size=10, lr=0.0001)
For fixed number of blocks=10, hidden_size=200, dataset_size=5000, lr=0.0001 Training for num_epochs=10 Epoch 2/10, Training Loss: -0.14358886231414297, Validation Loss: -0.2803282976742497 Epoch 4/10, Training Loss: -0.17533571784469215, Validation Loss: -0.4116998964801748 Epoch 6/10, Training Loss: -0.18580710415474394, Validation Loss: -0.4493791166138142 Epoch 8/10, Training Loss: -0.1960323669609021, Validation Loss: -0.4617179879482756 Epoch 10/10, Training Loss: -0.20351748117669063, Validation Loss: -0.48170990528578456 Training complete
MMD Score (num_epochs=10): 0.0073 Training for num_epochs=40 Epoch 2/40, Training Loss: -0.1378864387388934, Validation Loss: -0.3690203459973031 Epoch 4/40, Training Loss: -0.1580163512632928, Validation Loss: -0.38746916185668173 Epoch 6/40, Training Loss: -0.17779654987495053, Validation Loss: -0.43654367295985524 Epoch 8/40, Training Loss: -0.1889630897986618, Validation Loss: -0.4552901840590416 Epoch 10/40, Training Loss: -0.19047909550030123, Validation Loss: -0.42791273777789257 Epoch 12/40, Training Loss: -0.19354358423839915, Validation Loss: -0.4706975151883795 Epoch 14/40, Training Loss: -0.2010810967704112, Validation Loss: -0.46784974095669196 Epoch 16/40, Training Loss: -0.2049939930523661, Validation Loss: -0.44716153975496903 Epoch 18/40, Training Loss: -0.20869793099435893, Validation Loss: -0.49485309897585117 Epoch 20/40, Training Loss: -0.20657219410958616, Validation Loss: -0.4554901722263783 Epoch 22/40, Training Loss: -0.20321482946588235, Validation Loss: -0.46633801751948417 Epoch 24/40, Training Loss: -0.21119618913666768, Validation Loss: -0.4876342476048368 Epoch 26/40, Training Loss: -0.20632553188638253, Validation Loss: -0.46331152272351245 Epoch 28/40, Training Loss: -0.21113421246409417, Validation Loss: -0.4772615965376509 Epoch 30/40, Training Loss: -0.21129591796885838, Validation Loss: -0.4784514904022217 Epoch 32/40, Training Loss: -0.21280228827487338, Validation Loss: -0.4888522742276496 Epoch 34/40, Training Loss: -0.2113928150724281, Validation Loss: -0.4908938553739101 Epoch 36/40, Training Loss: -0.21354313719679008, Validation Loss: -0.5119861174137035 Epoch 38/40, Training Loss: -0.21345627887005156, Validation Loss: -0.45827306902154963 Epoch 40/40, Training Loss: -0.21106771359389478, Validation Loss: -0.4497500956058502 Training complete
MMD Score (num_epochs=40): 0.0046 Training for num_epochs=70 Epoch 2/70, Training Loss: -0.13076824862668715, Validation Loss: -0.37503495146619514 Epoch 4/70, Training Loss: -0.1627920740707354, Validation Loss: -0.40823768808486616 Epoch 6/70, Training Loss: -0.1742848001081835, Validation Loss: -0.4192498841501297 Epoch 8/70, Training Loss: -0.19004247720268638, Validation Loss: -0.4279984803275859 Epoch 10/70, Training Loss: -0.19786307911642573, Validation Loss: -0.46302134083940627 Epoch 12/70, Training Loss: -0.20096099491823805, Validation Loss: -0.42286763165859464 Epoch 14/70, Training Loss: -0.20947395664724436, Validation Loss: -0.3848408782418738 Epoch 16/70, Training Loss: -0.21476675871420992, Validation Loss: -0.47927849850756055 Epoch 18/70, Training Loss: -0.21008852278305726, Validation Loss: -0.5150847650588827 Epoch 20/70, Training Loss: -0.21329289078712463, Validation Loss: -0.5035497875923806 Epoch 22/70, Training Loss: -0.21433152136477557, Validation Loss: -0.4633709376162671 Epoch 24/70, Training Loss: -0.21784016591581432, Validation Loss: -0.4931467118415427 Epoch 26/70, Training Loss: -0.2165901618925008, Validation Loss: -0.5155834771217184 Epoch 28/70, Training Loss: -0.21100665737282145, Validation Loss: -0.51318791952539 Epoch 30/70, Training Loss: -0.21846945570273832, Validation Loss: -0.4833254655624958 Epoch 32/70, Training Loss: -0.21690888702869415, Validation Loss: -0.5067412986400279 Epoch 34/70, Training Loss: -0.22221683649854226, Validation Loss: -0.49461661437724497 Epoch 36/70, Training Loss: -0.21972130513326688, Validation Loss: -0.5188743472099304 Epoch 38/70, Training Loss: -0.2154702734117481, Validation Loss: -0.518978476524353 Epoch 40/70, Training Loss: -0.21986391347917644, Validation Loss: -0.5252069203143425 Epoch 42/70, Training Loss: -0.22182731194929642, Validation Loss: -0.5229760636674597 Epoch 44/70, Training Loss: -0.2206318766556003, Validation Loss: -0.5316517213557629 Epoch 46/70, Training Loss: -0.2191770705648444, Validation Loss: -0.5050297106834168 Epoch 48/70, Training Loss: -0.22091747705232012, Validation Loss: -0.5135451238206092 Epoch 50/70, Training Loss: -0.22272868650880726, Validation Loss: -0.507524533474699 Epoch 52/70, Training Loss: -0.22765119698914615, Validation Loss: -0.5047896504402161 Epoch 54/70, Training Loss: -0.21782343824478714, Validation Loss: -0.5143627763745633 Epoch 56/70, Training Loss: -0.22382738732478835, Validation Loss: -0.5082689720265409 Epoch 58/70, Training Loss: -0.2197471954436465, Validation Loss: -0.5090845556969338 Epoch 60/70, Training Loss: -0.2229800802401521, Validation Loss: -0.5346206350529448 Epoch 62/70, Training Loss: -0.22533906894651326, Validation Loss: -0.5416639390143942 Epoch 64/70, Training Loss: -0.2256680274890228, Validation Loss: -0.5050788458357466 Epoch 66/70, Training Loss: -0.22534454664723438, Validation Loss: -0.5375226907273556 Epoch 68/70, Training Loss: -0.2253061113709753, Validation Loss: -0.5076465005491008 Epoch 70/70, Training Loss: -0.22640712132508103, Validation Loss: -0.5100123501838522 Training complete
MMD Score (num_epochs=70): 0.0015 Training for num_epochs=100 Epoch 2/100, Training Loss: -0.13594088523902675, Validation Loss: -0.3640339044814414 Epoch 4/100, Training Loss: -0.1487430640072985, Validation Loss: -0.39018139021193726 Epoch 6/100, Training Loss: -0.1542976822873408, Validation Loss: -0.42238458230140363 Epoch 8/100, Training Loss: -0.16581505732187493, Validation Loss: -0.38754124382629673 Epoch 10/100, Training Loss: -0.17620162418831817, Validation Loss: -0.41299855043279365 Epoch 12/100, Training Loss: -0.1797038776524873, Validation Loss: -0.46067460800739046 Epoch 14/100, Training Loss: -0.1875358109193092, Validation Loss: -0.4512183618672351 Epoch 16/100, Training Loss: -0.19431041539223357, Validation Loss: -0.46909774491127504 Epoch 18/100, Training Loss: -0.19388091472739524, Validation Loss: -0.47596237412158476 Epoch 20/100, Training Loss: -0.19564713624475355, Validation Loss: -0.4955769181251526 Epoch 22/100, Training Loss: -0.19411445575004274, Validation Loss: -0.4284949873356109 Epoch 24/100, Training Loss: -0.20146982812068678, Validation Loss: -0.4502627342305285 Epoch 26/100, Training Loss: -0.2027834393422712, Validation Loss: -0.4436948429396812 Epoch 28/100, Training Loss: -0.20560864592817696, Validation Loss: -0.49413373876125255 Epoch 30/100, Training Loss: -0.20735824114897033, Validation Loss: -0.43310117055761055 Epoch 32/100, Training Loss: -0.20372144708579237, Validation Loss: -0.49495308323109405 Epoch 34/100, Training Loss: -0.2052411500364542, Validation Loss: -0.3276872003569882 Epoch 36/100, Training Loss: -0.20970047583634202, Validation Loss: -0.449070229175243 Epoch 38/100, Training Loss: -0.21172828359254212, Validation Loss: -0.4277269909990595 Epoch 40/100, Training Loss: -0.21106311015107415, Validation Loss: -0.45086967691462093 Epoch 42/100, Training Loss: -0.21024575511162932, Validation Loss: -0.4654041325792353 Epoch 44/100, Training Loss: -0.21036843491548843, Validation Loss: -0.47354090942981397 Epoch 46/100, Training Loss: -0.2124092212793502, Validation Loss: -0.5061322875479435 Epoch 48/100, Training Loss: -0.21252718266438353, Validation Loss: -0.48068189684380874 Epoch 50/100, Training Loss: -0.21141986224631018, Validation Loss: -0.49343108750404197 Epoch 52/100, Training Loss: -0.21551176594062285, Validation Loss: -0.49017775883065895 Epoch 54/100, Training Loss: -0.21384021821008486, Validation Loss: -0.45886730925833924 Epoch 56/100, Training Loss: -0.2167235734110529, Validation Loss: -0.49913877121945643 Epoch 58/100, Training Loss: -0.21804992294108325, Validation Loss: -0.4848848796905355 Epoch 60/100, Training Loss: -0.21321812539615415, Validation Loss: -0.4758980394677913 Epoch 62/100, Training Loss: -0.21669397191567855, Validation Loss: -0.4722903497675632 Epoch 64/100, Training Loss: -0.2187076737934893, Validation Loss: -0.4830910129115937 Epoch 66/100, Training Loss: -0.21772300817749718, Validation Loss: -0.5008368403353589 Epoch 68/100, Training Loss: -0.21606271212751216, Validation Loss: -0.5120836623171543 Epoch 70/100, Training Loss: -0.2175406933508136, Validation Loss: -0.48003573049890236 Epoch 72/100, Training Loss: -0.21688812923702327, Validation Loss: -0.5070201881388401 Epoch 74/100, Training Loss: -0.21162096383896742, Validation Loss: -0.5203415630979741 Epoch 76/100, Training Loss: -0.2169081981886517, Validation Loss: -0.483633760442125 Epoch 78/100, Training Loss: -0.22372531721537764, Validation Loss: -0.5064087939706254 Epoch 80/100, Training Loss: -0.21813252635977484, 
Validation Loss: -0.5143103402979831 Epoch 82/100, Training Loss: -0.2172615073621273, Validation Loss: -0.47639415841153326 Epoch 84/100, Training Loss: -0.21871006509119814, Validation Loss: -0.46699031727745177 Epoch 86/100, Training Loss: -0.2207058498678221, Validation Loss: -0.5056525043984677 Epoch 88/100, Training Loss: -0.22040676318786362, Validation Loss: -0.5213642380339034 Epoch 90/100, Training Loss: -0.21426767673004757, Validation Loss: -0.5214954779503194 Epoch 92/100, Training Loss: -0.21918462758714502, Validation Loss: -0.5133034390337924 Epoch 94/100, Training Loss: -0.22393759550018744, Validation Loss: -0.42689299164866246 Epoch 96/100, Training Loss: -0.22194279963997277, Validation Loss: -0.5066487827199571 Epoch 98/100, Training Loss: -0.2178570749407465, Validation Loss: -0.5179648266193715 Epoch 100/100, Training Loss: -0.21775149581530553, Validation Loss: -0.4457393101555236 Training complete
MMD Score (num_epochs=100): 0.0038
Hyperparameters involved: number of coupling blocks, dataset size, learning rate, and number of epochs.
We implemented RealNVP.sample(self, num_samples), which generates the requested number of synthetic points, and we report the MMD between a test set and the generated data points; to be more specific, we show that visually better results correspond to smaller MMD. With the calculate_loss function we do see training and improvement, though we are not entirely sure that the loss function we implemented is correct.
### conditional coupling layer
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import one_hot
class ConditionalCouplingLayer(nn.Module):
def __init__(self, input_size, hidden_size, condition_size):
"""
Initialize a ConditionalCouplingLayer.
Args:
- input_size (int): Total size of the input data.
- hidden_size (int): Size of the hidden layers in the neural networks.
- condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
"""
super(ConditionalCouplingLayer, self).__init__()
# Neural networks for the first half of the dimensions
self.fc1 = nn.Linear(input_size // 2 + condition_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
# Translation coefficient
self.fc3 = nn.Linear(hidden_size, input_size // 2)
# Scaling coefficient
self.fc4 = nn.Linear(hidden_size, input_size // 2)
def forward(self, x, condition):
"""
Forward pass through the ConditionalCouplingLayer.
Args:
- x (torch.Tensor): Input data.
- condition (torch.Tensor): Condition vector.
Returns:
- y (torch.Tensor): Transformed data.
- scaling_before_exp (torch.Tensor): Scaling coefficients before the exponential operation.
"""
# Split the input into two halves
x_a, x_b = x.chunk(2, dim=1)
# Concatenate conditions to the first half
x_a_concat = torch.cat([x_a, condition], dim=1)
# Apply neural network to calculate coefficients
h = F.relu(self.fc1(x_a_concat))
h = F.relu(self.fc2(h))
translation = self.fc3(h)
scaling_before_exp = torch.tanh(self.fc4(h))
scaling = torch.exp(scaling_before_exp)
# Apply the affine transformation
y_b = x_b * scaling + translation
# Concatenate the transformed halves
y = torch.cat([x_a, y_b], dim=1)
return y, scaling_before_exp
def backward(self, y, condition):
"""
Backward pass through the ConditionalCouplingLayer.
Args:
- y (torch.Tensor): Transformed data.
- condition (torch.Tensor): Condition vector.
Returns:
- x (torch.Tensor): Reconstructed original input.
"""
# Split the input into two halves
y_a, y_b = y.chunk(2, dim=1)
# Concatenate conditions to the first half
y_a_concat = torch.cat([y_a, condition], dim=1)
# Apply neural network to calculate coefficients (reverse)
h = F.relu(self.fc1(y_a_concat))
h = F.relu(self.fc2(h))
translation = self.fc3(h)
        scaling_before_exp = torch.tanh(self.fc4(h))
        scaling = torch.exp(scaling_before_exp)
# Reverse the operations to reconstruct the original input
x_a = y_a
x_b = (y_b - translation) / scaling
# Concatenate the reconstructed halves
x = torch.cat([x_a, x_b], dim=1)
return x
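The same round-trip property holds for a conditional layer when forward and backward receive the same condition. A quick check, again a hypothetical sketch we added (layer, tensors, and conditions are illustrative):
### hypothetical check: a conditional coupling layer inverts given the same condition
layer_check = ConditionalCouplingLayer(input_size=2, hidden_size=16, condition_size=2)
xc = torch.randn(4, 2)
cond = torch.eye(2).repeat(2, 1)  # four one-hot conditions, alternating labels
with torch.no_grad():
    yc, _ = layer_check.forward(xc, cond)
    xc_rec = layer_check.backward(yc, cond)
print(torch.allclose(xc, xc_rec, atol=1e-6))  # expected: True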
### conditional real NVP class
class ConditionalRealNVP(nn.Module):
def __init__(self, input_size, hidden_size, condition_size, blocks):
"""
Initialize a ConditionalRealNVP model.
Args:
- input_size (int): Total size of the input data.
- hidden_size (int): Size of the hidden layers in the neural networks.
- condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
- blocks (int): Number of coupling layers in the model.
"""
super(ConditionalRealNVP, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.condition_size = condition_size
self.blocks = blocks
# List of coupling layers
self.coupling_layers = nn.ModuleList([
ConditionalCouplingLayer(input_size, hidden_size, condition_size) for _ in range(blocks)
])
        # Fixed random orthonormal matrices (a plain Python list: not registered as buffers,
        # so they won't move with .to(device) or be saved in state_dict)
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]
# List to store scaling_before_exp for each block
self.scaling_before_exp_list = []
def _get_orthonormal_matrix(self, size):
"""
Generate a random orthonormal matrix.
Args:
- size (int): Size of the matrix.
Returns:
- q (torch.Tensor): Orthonormal matrix.
"""
w = torch.randn(size, size)
q, _ = torch.linalg.qr(w, 'reduced')
return q
def forward_realnvp(self, x, condition):
"""
Forward pass through the ConditionalRealNVP model.
Args:
- x (torch.Tensor): Input data.
- condition (torch.Tensor): Condition vector.
Returns:
- x (torch.Tensor): Transformed data.
"""
scaling_before_exp_list = []
for i in range(self.blocks):
#print("x is:"); print(x)
#print("shape of x is:"); print(x.shape)
x = torch.matmul(x, self.orthonormal_matrices[i])
x, scaling_before_exp = self.coupling_layers[i].forward(x, condition)
scaling_before_exp_list.append(scaling_before_exp)
self.scaling_before_exp_list = scaling_before_exp_list
return x
def decode(self, z, condition):
"""
Reverse transformations to decode data.
Args:
- z (torch.Tensor): Transformed data.
- condition (torch.Tensor): Condition vector.
Returns:
- z (torch.Tensor): Reconstructed original data.
"""
for i in reversed(range(self.blocks)):
z = self.coupling_layers[i].backward(z, condition)
z = torch.matmul(z, self.orthonormal_matrices[i].t())
return z
def sample(self, num_samples=1000, conditions=None):
"""
Generate synthetic samples.
Args:
- num_samples (int): Number of synthetic samples to generate.
- conditions (torch.Tensor): Conditions for generating synthetic samples.
Returns:
- synthetic_samples (torch.Tensor): Synthetic samples.
"""
with torch.no_grad():
z = torch.randn(num_samples, self.input_size)
synthetic_samples = self.decode(z, conditions)
return synthetic_samples
### training_the_conditional_nvp model
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
def train_and_validate_conditional_nvp(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
"""
Train the ConditionalRealNVP model and evaluate on a validation dataset.
Args:
- model (ConditionalRealNVP): The ConditionalRealNVP model to be trained.
- train_loader (DataLoader): DataLoader for the training dataset.
- val_loader (DataLoader): DataLoader for the validation dataset.
- num_epochs (int): Number of training epochs.
- lr (float): Learning rate for the optimizer.
- print_after (int): Number of epochs after which to print the training and validation loss.
Returns:
- train_losses (list): List of training losses for each epoch.
- val_losses (list): List of validation losses for each epoch.
"""
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
train_losses = [] # List to store training losses
val_losses = [] # List to store validation losses
for epoch in range(num_epochs):
total_train_loss = 0.0
# Training phase
model.train() # Set the model to training mode
for data, labels in train_loader:
inputs = data
conditions = one_hot(labels, num_classes=model.condition_size).float()
# Zero the gradients
optimizer.zero_grad()
# Forward pass (encoding)
encoded = model.forward_realnvp(inputs, conditions)
# Loss calculation
train_loss = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
# Backward pass (gradient computation)
train_loss.backward()
# Clip gradients to prevent exploding gradients
clip_grad_norm_(model.parameters(), max_norm=1.0)
# Update weights
optimizer.step()
total_train_loss += train_loss.item()
# Average training loss for the epoch
average_train_loss = total_train_loss / len(train_loader)
# Validation phase
if val_loader is not None:
model.eval() # Set the model to evaluation mode
total_val_loss = 0.0
with torch.no_grad():
for val_data, val_labels in val_loader:
val_inputs = val_data
val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()
# Forward pass (encoding) for validation
val_encoded = model.forward_realnvp(val_inputs, val_conditions)
# Loss calculation for validation
val_loss = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
total_val_loss += val_loss.item()
# Average validation loss for the epoch
average_val_loss = total_val_loss / len(val_loader)
# Print training and validation losses together
if (epoch + 1) % print_after == 0:
print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss}, Validation Loss: {average_val_loss}")
# Append losses to the lists
train_losses.append(average_train_loss)
val_losses.append(average_val_loss)
# Set the model back to training mode
model.train()
print("Training complete")
return train_losses, val_losses
### Create the dataset and dataloaders for the conditional NVP model
dataset_sizes = [ 100, 200,300,400,500,600,700,800,900, 1000, 5000]
# Generate datasets of varying sizes
train_datasets = {}
val_datasets = {}
datasets = {}
for size in dataset_sizes:
X, y = make_moons(n_samples=size, noise=0.1)
    datasets[size] = {'data': X, 'labels': y}  ### the label indicates which moon a point belongs to
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
train_datasets[size] = {'data': torch.FloatTensor(X_train), 'label': y_train}
val_datasets[size] = {'data': torch.FloatTensor(X_test), 'label': y_test}
#### data for the two-moons model
from torch.utils.data import TensorDataset, DataLoader
# Define a custom dataset
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.labels[index]
return x, y
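To make the data format explicit, here is a small demo (hypothetical data, added for illustration) of what the conditional training loop consumes: a batch of inputs together with one-hot conditions built from the labels.
### hypothetical demo: one batch from CustomDataset plus its one-hot conditions
demo_ds = CustomDataset(torch.randn(6, 2), torch.tensor([0, 1, 0, 1, 1, 0]))
demo_loader = DataLoader(demo_ds, batch_size=3, shuffle=False)
xb, yb = next(iter(demo_loader))
print(xb.shape, one_hot(yb, num_classes=2).float())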
# Define model parameters
input_size = 2
hidden_size = 200
condition_size = 2
blocks = 10
# Initialize the model
conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, blocks)
# Define hyperparameters
num_epochs = 10
lr = 0.0001
# Create datasets
dataset_size=1000
train_dataset = CustomDataset(train_datasets[dataset_size]['data'], train_datasets[dataset_size]['label'])
val_dataset = CustomDataset(val_datasets[dataset_size]['data'], val_datasets[dataset_size]['label'])
# Define batch size
batch_size = 32
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Task 1: Train the Conditional INN
train_loss, val_loss= train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
num_epochs=num_epochs, lr=lr, print_after=1)
# # plotting the loss
# plot_losses(train_losses[3:], val_losses[3:], want_log_scale=0)
# plt.show()
# Choose a label for evaluation (e.g., label 0)
eval_condition = torch.tensor([[1.0, 0.0]])  # One-hot encoding for label 0 (float, to match the input dtype)
# Repeat the condition vector for each sample
eval_condition = eval_condition.repeat(1000, 1)
with torch.no_grad():
# Generate synthetic samples for the chosen label
synthetic_samples_label_0 = conditional_inn_model.sample(num_samples=1000, conditions=eval_condition)
# # Example usage:
# plot_code_distribution(model=model, test_loader=val_loader, num_samples=1000)
# plt.show()
### plot the synthetic data and the original data
visualize_synthetic_data(train_datasets[1000]['data'], synthetic_samples_label_0)
plt.show()
# Generate synthetic samples for all labels
conditions_all_labels = torch.eye(condition_size) # Assuming one-hot encoding
# Tile the two one-hot condition vectors: 1000 samples per label, 2000 rows in total
conditions_all_labels = conditions_all_labels.repeat(1000, 1)
with torch.no_grad():
synthetic_samples_all_labels = conditional_inn_model.sample(num_samples=2000, conditions=conditions_all_labels)
visualize_synthetic_data(train_datasets[1000]['data'], synthetic_samples_all_labels)
plt.show()
First we load the digits dataset and create training subsets of different sizes.
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import torch
# Load the digits dataset
digits = load_digits()
# Define the dataset percentages
dataset_percentages = [0.1, 0.5, 1]
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2)
val_dataset = {'X': torch.FloatTensor(X_test), 'y': y_test}
# Generate datasets of varying sizes
train_datasets = {}
for percentage in dataset_percentages:
# Take a subset of the digits dataset based on the desired size
size = int(len(y_train)*percentage)
X, y = X_train[:size], y_train[:size]
train_datasets[percentage] = {'X': torch.FloatTensor(X), 'y': y}
def plot_code_distribution(model, test_loader):
"""
Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.
Args:
- model (RealNVP): Trained RealNVP model.
- test_loader (DataLoader): DataLoader for the test dataset.
- num_samples (int): Number of samples to visualize.
Returns:
None (displays the plot).
"""
model.eval() # Set the model to evaluation mode
fig, axs = plt.subplots(2, 5, figsize=(20, 7))
with torch.no_grad():
# Concatenate multiple batches to obtain more samples
test_samples = torch.cat([batch for batch in test_loader], dim=0)
# Assuming your model has an `encode` method
code_samples = model.encode(test_samples)
# Convert PyTorch tensor to numpy array
code_np = code_samples.numpy()
dim_1 = 0
dim_2 = 1
for i in range(2):
for j in range(5):
# Scatter plot of code distribution
axs[i,j].scatter(code_np[:, dim_1], code_np[:, dim_2], label='Code Distribution', alpha=0.5)
axs[i,j].set_xlabel(f"Code Dimension {dim_1}")
axs[i,j].set_ylabel(f"Code Dimension {dim_2}")
                axs[i,j].set_title(f'Code dims {dim_1} vs {dim_2}')
dim_1 += 1
dim_2 += 1
plt.tight_layout()
plt.show()
def visualize_synthetic_data(synthetic_data, title=""):
"""
Scatter plot to visualize the original and synthetic data in 2D.
Args:
- synthetic_data (torch.Tensor): Synthetic data.
Returns:
- None: Displays the scatter plot.
"""
fig, axs = plt.subplots(2, 5, figsize=(20, 7))
# Ensure both original and synthetic data are converted to numpy arrays
with torch.no_grad():
# Convert PyTorch tensors to numpy arrays
synthetic_np = synthetic_data.numpy()
count = 0
for i in range(2):
for j in range(5):
axs[i,j].imshow(synthetic_np[count].reshape(8, 8), cmap='gray')
axs[i,j].set_title(f'Synthetic Image: {count}')
count += 1
    # Add an overall figure title
    fig.suptitle(title)
plt.show()
Now we test whether our network works on the new dataset.
input_size = 64
hidden_size = 200
blocks = 10
print_after=1
# initialize dataloader
dataset_percentage = 0.5
batch_size=32
data_considered=train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=batch_size, shuffle=True)
val_loader= torch.utils.data.DataLoader(val_dataset['X'], batch_size=batch_size, shuffle=True)
# instantiate the model
model= RealNVP(input_size=input_size, hidden_size= hidden_size, blocks=blocks)
## train the model
train_losses, val_losses= train_and_evaluate(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1)
# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()
# Example usage:
plot_code_distribution(model=model, test_loader=val_loader)
plt.show()
### plot the synthetic data and the original data
synthetic_data=model.sample(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data)
plt.show()
Epoch 1/10, Training Loss: 19308.744353252907, Validation Loss: 308.0241241455078
Epoch 2/10, Training Loss: 151.0023258043372, Validation Loss: 260.6484127044678
Epoch 3/10, Training Loss: 129.43456749294114, Validation Loss: 241.6455866495768
Epoch 4/10, Training Loss: 115.87873409105384, Validation Loss: 231.48429171244302
Epoch 5/10, Training Loss: 106.71935371730639, Validation Loss: 228.04992612202963
Epoch 6/10, Training Loss: 100.71591966048531, Validation Loss: 232.55555311838785
Epoch 7/10, Training Loss: 95.37316413547681, Validation Loss: 227.84586747487387
Epoch 8/10, Training Loss: 91.04394149780273, Validation Loss: 230.0734135309855
Epoch 9/10, Training Loss: 87.94942756321119, Validation Loss: 229.5797259012858
Epoch 10/10, Training Loss: 83.55091410097869, Validation Loss: 240.51764233907065
Training complete
The code distribution looks quite reasonable and Gaussian-distributed. The synthesized data is beginning to look like digits, but is still quite noisy. Next we try to find the optimal hyperparameters.
learning_rates = [0.01,0.005,0.0005,0.0001,0.000005]
dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)
for lr in learning_rates:
print(f"\nTraining for learning rate={lr}")
# Instantiate the model
model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks)
# Train the model
train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=30, lr=lr, print_after=1)
# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()
# Example usage:
plot_code_distribution(model=model, test_loader=val_loader)
plt.show()
### plot the synthetic data
synthetic_data=model.sample(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data)
Training for learning rate=0.01 Epoch 1/30, Training Loss: 8009.836741299099, Validation Loss: 288.0275885264079 Epoch 2/30, Training Loss: 76.08619503445095, Validation Loss: 256.4709021250407 Epoch 3/30, Training Loss: 68.7188730875651, Validation Loss: 241.0039800008138 Epoch 4/30, Training Loss: 65.80661010742188, Validation Loss: 232.09345626831055 Epoch 5/30, Training Loss: 63.3587641398112, Validation Loss: 228.34066772460938 Epoch 6/30, Training Loss: 61.78260345458985, Validation Loss: 224.1435661315918 Epoch 7/30, Training Loss: 61.00211885240343, Validation Loss: 221.20492553710938 Epoch 8/30, Training Loss: 59.872193315294055, Validation Loss: 222.45267963409424 Epoch 9/30, Training Loss: 58.49332021077474, Validation Loss: 216.29072443644205 Epoch 10/30, Training Loss: 57.88600684271918, Validation Loss: 214.59373696645102 Epoch 11/30, Training Loss: 57.48555611504449, Validation Loss: 212.99246056874594 Epoch 12/30, Training Loss: 57.16072599622938, Validation Loss: 217.7744598388672 Epoch 13/30, Training Loss: 56.830532752143014, Validation Loss: 210.1707499821981 Epoch 14/30, Training Loss: 56.36947267320421, Validation Loss: 211.24468994140625 Epoch 15/30, Training Loss: 55.42196782430013, Validation Loss: 210.97901821136475 Epoch 16/30, Training Loss: 55.75337677001953, Validation Loss: 208.09878635406494 Epoch 17/30, Training Loss: 55.02196782430013, Validation Loss: 210.40148131052652 Epoch 18/30, Training Loss: 54.59654473198785, Validation Loss: 205.72683238983154 Epoch 19/30, Training Loss: 54.583910878499346, Validation Loss: 208.6036809285482 Epoch 20/30, Training Loss: 54.372396850585936, Validation Loss: 206.71865940093994 Epoch 21/30, Training Loss: 53.65991999308268, Validation Loss: 204.12037054697672 Epoch 22/30, Training Loss: 53.65253762139214, Validation Loss: 208.5793809890747 Epoch 23/30, Training Loss: 53.74359368218316, Validation Loss: 211.7449827194214 Epoch 24/30, Training Loss: 53.684653727213544, Validation Loss: 207.8568785985311 Epoch 25/30, Training Loss: 53.2072132534451, Validation Loss: 207.94135189056396 Epoch 26/30, Training Loss: 53.41186709933811, Validation Loss: 209.1408322652181 Epoch 27/30, Training Loss: 52.97333077324761, Validation Loss: 204.48490047454834 Epoch 28/30, Training Loss: 52.960265689425995, Validation Loss: 203.9521099726359 Epoch 29/30, Training Loss: 52.629005432128906, Validation Loss: 206.9331143697103 Epoch 30/30, Training Loss: 53.477005343967015, Validation Loss: 205.88944085439047 Training complete
Training for learning rate=0.005 Epoch 1/30, Training Loss: 11817.525778198242, Validation Loss: 270.10508664449054 Epoch 2/30, Training Loss: 69.870982615153, Validation Loss: 232.40014362335205 Epoch 3/30, Training Loss: 62.34513236151801, Validation Loss: 220.9430373509725 Epoch 4/30, Training Loss: 58.622605387369795, Validation Loss: 211.92736530303955 Epoch 5/30, Training Loss: 55.77464650472005, Validation Loss: 207.57684199015299 Epoch 6/30, Training Loss: 53.64491526285807, Validation Loss: 202.77793534596762 Epoch 7/30, Training Loss: 51.990141211615665, Validation Loss: 199.2128356297811 Epoch 8/30, Training Loss: 50.460140482584634, Validation Loss: 198.17747497558594 Epoch 9/30, Training Loss: 50.364815860324434, Validation Loss: 198.5120356877645 Epoch 10/30, Training Loss: 49.0253177218967, Validation Loss: 196.53202692667642 Epoch 11/30, Training Loss: 47.818701171875, Validation Loss: 194.4502503077189 Epoch 12/30, Training Loss: 47.43562757703993, Validation Loss: 191.5679677327474 Epoch 13/30, Training Loss: 46.55343619452582, Validation Loss: 197.71495151519775 Epoch 14/30, Training Loss: 45.96852815416124, Validation Loss: 191.3148282368978 Epoch 15/30, Training Loss: 45.09313100179036, Validation Loss: 195.80774847666422 Epoch 16/30, Training Loss: 45.09204661051432, Validation Loss: 192.35662587483725 Epoch 17/30, Training Loss: 44.720115661621094, Validation Loss: 191.43877792358398 Epoch 18/30, Training Loss: 44.07941326565213, Validation Loss: 193.19898001352945 Epoch 19/30, Training Loss: 44.19743118286133, Validation Loss: 190.72268931070963 Epoch 20/30, Training Loss: 43.4016234503852, Validation Loss: 185.80727926890054 Epoch 21/30, Training Loss: 43.52100160386827, Validation Loss: 188.30844116210938 Epoch 22/30, Training Loss: 42.71001519097222, Validation Loss: 192.54564380645752 Epoch 23/30, Training Loss: 42.025776926676436, Validation Loss: 192.11276976267496 Epoch 24/30, Training Loss: 42.15523783365885, Validation Loss: 189.09472274780273 Epoch 25/30, Training Loss: 41.77534535725911, Validation Loss: 191.74198309580484 Epoch 26/30, Training Loss: 41.31961042616103, Validation Loss: 194.46389230092367 Epoch 27/30, Training Loss: 40.85063256157769, Validation Loss: 195.73884105682373 Epoch 28/30, Training Loss: 40.96485850016276, Validation Loss: 193.3294941584269 Epoch 29/30, Training Loss: 41.45541627671984, Validation Loss: 194.46637217203775 Epoch 30/30, Training Loss: 40.55387997097439, Validation Loss: 190.83805497487387 Training complete
Training for learning rate=0.0005 Epoch 1/30, Training Loss: 3421.306739637587, Validation Loss: 288.54217465718585 Epoch 2/30, Training Loss: 73.53326305813259, Validation Loss: 245.20367558797201 Epoch 3/30, Training Loss: 64.15791592068142, Validation Loss: 227.901712735494 Epoch 4/30, Training Loss: 58.57548853556315, Validation Loss: 217.85050106048584 Epoch 5/30, Training Loss: 55.11601825290256, Validation Loss: 212.453226407369 Epoch 6/30, Training Loss: 52.042591603597, Validation Loss: 209.29950936635336 Epoch 7/30, Training Loss: 49.50460010104709, Validation Loss: 204.23781808217367 Epoch 8/30, Training Loss: 47.800843217637805, Validation Loss: 207.67214838663736 Epoch 9/30, Training Loss: 46.12979066636827, Validation Loss: 207.23672898610434 Epoch 10/30, Training Loss: 44.459814368353946, Validation Loss: 202.5842374165853 Epoch 11/30, Training Loss: 43.03138614230686, Validation Loss: 207.65548356374106 Epoch 12/30, Training Loss: 41.74035085042318, Validation Loss: 208.69722620646158 Epoch 13/30, Training Loss: 40.926596577962236, Validation Loss: 213.01070054372153 Epoch 14/30, Training Loss: 39.39276572333442, Validation Loss: 210.76309076944986 Epoch 15/30, Training Loss: 38.11468183729384, Validation Loss: 221.10226694742838 Epoch 16/30, Training Loss: 37.34300376044379, Validation Loss: 216.53514099121094 Epoch 17/30, Training Loss: 36.546615261501735, Validation Loss: 226.68182563781738 Epoch 18/30, Training Loss: 36.00442470974392, Validation Loss: 228.22372436523438 Epoch 19/30, Training Loss: 34.968190087212456, Validation Loss: 227.42056020100912 Epoch 20/30, Training Loss: 34.6193118625217, Validation Loss: 221.6729122797648 Epoch 21/30, Training Loss: 34.06460901896159, Validation Loss: 240.29690742492676 Epoch 22/30, Training Loss: 33.063655853271484, Validation Loss: 242.19485092163086 Epoch 23/30, Training Loss: 32.59690784878201, Validation Loss: 249.20933310190836 Epoch 24/30, Training Loss: 31.83392054239909, Validation Loss: 248.67272027333578 Epoch 25/30, Training Loss: 31.582435353597006, Validation Loss: 256.09630997975665 Epoch 26/30, Training Loss: 30.775786675347224, Validation Loss: 266.1253210703532 Epoch 27/30, Training Loss: 30.550120205349394, Validation Loss: 252.3609733581543 Epoch 28/30, Training Loss: 30.08771603902181, Validation Loss: 256.7566725413005 Epoch 29/30, Training Loss: 29.445677778455945, Validation Loss: 263.6085141499837 Epoch 30/30, Training Loss: 28.81046553717719, Validation Loss: 275.15847905476886 Training complete
Training for learning rate=0.0001 Epoch 1/30, Training Loss: 50113.41761644151, Validation Loss: 436.48069826761883 Epoch 2/30, Training Loss: 101.97941606309679, Validation Loss: 320.2299327850342 Epoch 3/30, Training Loss: 85.60570593939887, Validation Loss: 290.40064366658527 Epoch 4/30, Training Loss: 78.33708835177951, Validation Loss: 272.7111422220866 Epoch 5/30, Training Loss: 73.4560770670573, Validation Loss: 260.9111525217692 Epoch 6/30, Training Loss: 69.79727257622613, Validation Loss: 252.03315226236978 Epoch 7/30, Training Loss: 66.88374820285374, Validation Loss: 245.57611910502115 Epoch 8/30, Training Loss: 64.33567445543078, Validation Loss: 239.41691970825195 Epoch 9/30, Training Loss: 62.18655675252278, Validation Loss: 234.46007283528647 Epoch 10/30, Training Loss: 60.42937969631619, Validation Loss: 231.49946689605713 Epoch 11/30, Training Loss: 58.72264811197917, Validation Loss: 229.10130310058594 Epoch 12/30, Training Loss: 57.17389127943251, Validation Loss: 226.6750087738037 Epoch 13/30, Training Loss: 55.773566012912326, Validation Loss: 225.00317096710205 Epoch 14/30, Training Loss: 54.53753424750434, Validation Loss: 222.3171361287435 Epoch 15/30, Training Loss: 53.36913740370009, Validation Loss: 222.87827587127686 Epoch 16/30, Training Loss: 52.27295362684462, Validation Loss: 223.05616505940756 Epoch 17/30, Training Loss: 51.29743626912435, Validation Loss: 222.3541399637858 Epoch 18/30, Training Loss: 50.20012520684136, Validation Loss: 219.63851642608643 Epoch 19/30, Training Loss: 49.292029995388454, Validation Loss: 223.75088246663412 Epoch 20/30, Training Loss: 48.423710378011066, Validation Loss: 225.47177600860596 Epoch 21/30, Training Loss: 47.65950080023872, Validation Loss: 224.14251295725504 Epoch 22/30, Training Loss: 46.83405244615343, Validation Loss: 224.9314339955648 Epoch 23/30, Training Loss: 46.0850218878852, Validation Loss: 228.0197811126709 Epoch 24/30, Training Loss: 45.30505820380317, Validation Loss: 230.37444273630777 Epoch 25/30, Training Loss: 44.65458026462131, Validation Loss: 232.04700247446695 Epoch 26/30, Training Loss: 43.974732123480905, Validation Loss: 237.41689682006836 Epoch 27/30, Training Loss: 43.24482091267904, Validation Loss: 234.12914307912192 Epoch 28/30, Training Loss: 42.66191660563151, Validation Loss: 237.63699022928873 Epoch 29/30, Training Loss: 41.99319805569119, Validation Loss: 241.9967892964681 Epoch 30/30, Training Loss: 41.384490712483725, Validation Loss: 248.90463892618814 Training complete
Training for learning rate=5e-06 Epoch 1/30, Training Loss: 209076.60121527777, Validation Loss: 405062.1399739583 Epoch 2/30, Training Loss: 69862.05911458333, Validation Loss: 127279.57747395833 Epoch 3/30, Training Loss: 22368.46755642361, Validation Loss: 39132.588623046875 Epoch 4/30, Training Loss: 7067.427842881944, Validation Loss: 13500.332214355469 Epoch 5/30, Training Loss: 2560.1897352430556, Validation Loss: 5463.914642333984 Epoch 6/30, Training Loss: 1113.8959011501736, Validation Loss: 2742.546442667643 Epoch 7/30, Training Loss: 608.1996073404948, Validation Loss: 1645.3327814737956 Epoch 8/30, Training Loss: 386.06957058376736, Validation Loss: 1121.6913019816081 Epoch 9/30, Training Loss: 277.5015930175781, Validation Loss: 843.8144543965658 Epoch 10/30, Training Loss: 216.93083055284288, Validation Loss: 683.6320292154948 Epoch 11/30, Training Loss: 180.46121758355034, Validation Loss: 585.8069547017416 Epoch 12/30, Training Loss: 157.36595323350696, Validation Loss: 522.2299512227377 Epoch 13/30, Training Loss: 142.15333811442056, Validation Loss: 478.9925130208333 Epoch 14/30, Training Loss: 131.64742160373265, Validation Loss: 448.8523184458415 Epoch 15/30, Training Loss: 124.23250885009766, Validation Loss: 426.7560176849365 Epoch 16/30, Training Loss: 118.68551194932726, Validation Loss: 410.0090611775716 Epoch 17/30, Training Loss: 114.39310455322266, Validation Loss: 396.76917203267413 Epoch 18/30, Training Loss: 110.9672132703993, Validation Loss: 386.13282267252606 Epoch 19/30, Training Loss: 108.15684068467883, Validation Loss: 377.21822293599445 Epoch 20/30, Training Loss: 105.78033989800348, Validation Loss: 369.6262327829997 Epoch 21/30, Training Loss: 103.73234710693359, Validation Loss: 363.0723139444987 Epoch 22/30, Training Loss: 101.93745642768012, Validation Loss: 357.2292785644531 Epoch 23/30, Training Loss: 100.34205101860894, Validation Loss: 352.077735265096 Epoch 24/30, Training Loss: 98.9117190890842, Validation Loss: 347.46399943033856 Epoch 25/30, Training Loss: 97.61323394775391, Validation Loss: 343.2505931854248 Epoch 26/30, Training Loss: 96.42156558566623, Validation Loss: 339.3616517384847 Epoch 27/30, Training Loss: 95.32410888671875, Validation Loss: 335.8017037709554 Epoch 28/30, Training Loss: 94.30238511827257, Validation Loss: 332.52455139160156 Epoch 29/30, Training Loss: 93.36457349989149, Validation Loss: 329.5176232655843 Epoch 30/30, Training Loss: 92.4872580634223, Validation Loss: 326.7073853810628 Training complete
We see that the best results for the validation loss and the generated images are obtained with a learning rate of lr = 0.005. Since the algorithm already starts to overfit at around 19 epochs, we do not need to analyze the epoch count and the dataset size separately: we have already found the best epoch count, and given the overfitting that occurs, it would not make sense to test even smaller dataset sizes.
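Rather than reading the best epoch off the printed logs by eye, it can also be picked programmatically. Below is a minimal, self-contained sketch; the values are a truncated excerpt of the lr = 0.005 validation log above, purely for illustration.
import numpy as np
# Excerpt of the validation losses from the lr = 0.005 run above (epochs 1-8)
val_log = np.array([270.11, 232.40, 220.94, 211.93, 207.58, 202.78, 199.21, 198.18])
best_epoch = int(np.argmin(val_log)) + 1  # epochs are 1-indexed in the printed log
print(f"Lowest validation loss so far: {val_log.min():.2f} at epoch {best_epoch}")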
hidden_sizes = [100, 200, 400]
blocks = [2, 5, 10]
input_size = 64
dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)
for hidden_size in hidden_sizes:
    for block in blocks:
        print(f"\nTraining for hidden_size={hidden_size}, blocks = {block}")
        # Instantiate the model
        model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=block)
        # Train the model
        train_losses, val_losses = train_and_evaluate(model, train_loader, val_loader, num_epochs=30, lr=0.005, print_after=1)
        # plotting the loss
        plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
        plt.show()
        # Example usage:
        plot_code_distribution(model=model, test_loader=val_loader)
        plt.show()
        ### plot the synthetic data
        synthetic_data = model.sample(num_samples=len(data_considered))
        visualize_synthetic_data(synthetic_data)
Training for hidden_size=100, blocks = 2 Epoch 1/30, Training Loss: 294.7230936686198, Validation Loss: 571.6155484517416 Epoch 2/30, Training Loss: 141.9142296685113, Validation Loss: 468.15574010213214 Epoch 3/30, Training Loss: 124.05112999810113, Validation Loss: 420.20416831970215 Epoch 4/30, Training Loss: 114.67127126057943, Validation Loss: 405.47664070129395 Epoch 5/30, Training Loss: 112.06046108669705, Validation Loss: 401.04412778218585 Epoch 6/30, Training Loss: 109.06603291829427, Validation Loss: 385.57850710550946 Epoch 7/30, Training Loss: 104.84836205376519, Validation Loss: 383.4814968109131 Epoch 8/30, Training Loss: 102.67653469509548, Validation Loss: 399.0242977142334 Epoch 9/30, Training Loss: 102.31197408040364, Validation Loss: 360.4715805053711 Epoch 10/30, Training Loss: 99.84381306966146, Validation Loss: 357.0197696685791 Epoch 11/30, Training Loss: 99.02914411756727, Validation Loss: 355.18566576639813 Epoch 12/30, Training Loss: 97.31855146620009, Validation Loss: 351.6528517405192 Epoch 13/30, Training Loss: 96.65306871202257, Validation Loss: 358.52728335062665 Epoch 14/30, Training Loss: 95.8873289320204, Validation Loss: 358.1687469482422 Epoch 15/30, Training Loss: 94.39329071044922, Validation Loss: 355.81766446431476 Epoch 16/30, Training Loss: 94.19340837266711, Validation Loss: 337.2964045206706 Epoch 17/30, Training Loss: 92.56639607747395, Validation Loss: 338.4132130940755 Epoch 18/30, Training Loss: 92.66181182861328, Validation Loss: 327.18345133463544 Epoch 19/30, Training Loss: 92.0412845187717, Validation Loss: 331.57775179545087 Epoch 20/30, Training Loss: 90.90107608371311, Validation Loss: 334.78138478597003 Epoch 21/30, Training Loss: 89.76652086046008, Validation Loss: 332.50303332010907 Epoch 22/30, Training Loss: 89.18487158881294, Validation Loss: 332.22913614908856 Epoch 23/30, Training Loss: 89.43389723036024, Validation Loss: 327.5421390533447 Epoch 24/30, Training Loss: 88.90379621717665, Validation Loss: 328.9135939280192 Epoch 25/30, Training Loss: 89.27030453152126, Validation Loss: 313.00710614522296 Epoch 26/30, Training Loss: 88.07124938964844, Validation Loss: 307.8485215504964 Epoch 27/30, Training Loss: 86.89611290825738, Validation Loss: 308.5803909301758 Epoch 28/30, Training Loss: 85.90565677218967, Validation Loss: 328.0444863637288 Epoch 29/30, Training Loss: 88.49742211235895, Validation Loss: 319.33860270182294 Epoch 30/30, Training Loss: 88.56594645182291, Validation Loss: 309.60890515645343 Training complete
Training for hidden_size=100, blocks = 5 Epoch 1/30, Training Loss: 265.01097886827256, Validation Loss: 301.99928347269696 Epoch 2/30, Training Loss: 77.97027062310113, Validation Loss: 256.44384956359863 Epoch 3/30, Training Loss: 69.68282352023654, Validation Loss: 241.7223745981852 Epoch 4/30, Training Loss: 65.50670505099826, Validation Loss: 230.7071574529012 Epoch 5/30, Training Loss: 62.328228335910374, Validation Loss: 224.24456691741943 Epoch 6/30, Training Loss: 60.30294715033637, Validation Loss: 218.9724578857422 Epoch 7/30, Training Loss: 58.39633288913303, Validation Loss: 218.47990926106772 Epoch 8/30, Training Loss: 56.874779001871744, Validation Loss: 212.0576960245768 Epoch 9/30, Training Loss: 56.036871846516924, Validation Loss: 210.8950433731079 Epoch 10/30, Training Loss: 55.106080034044055, Validation Loss: 207.61690107981363 Epoch 11/30, Training Loss: 54.76850458780925, Validation Loss: 213.44031524658203 Epoch 12/30, Training Loss: 53.71915673149957, Validation Loss: 210.15140438079834 Epoch 13/30, Training Loss: 52.99325018988715, Validation Loss: 203.27078851064047 Epoch 14/30, Training Loss: 52.25531472100152, Validation Loss: 202.80283133188883 Epoch 15/30, Training Loss: 51.514439307318796, Validation Loss: 200.92092609405518 Epoch 16/30, Training Loss: 51.71739222208659, Validation Loss: 199.66056442260742 Epoch 17/30, Training Loss: 50.8276117960612, Validation Loss: 200.95742066701254 Epoch 18/30, Training Loss: 50.33599675496419, Validation Loss: 203.53884760538736 Epoch 19/30, Training Loss: 49.800680796305336, Validation Loss: 203.55471007029215 Epoch 20/30, Training Loss: 49.49485787285699, Validation Loss: 200.13616148630777 Epoch 21/30, Training Loss: 48.99873826768663, Validation Loss: 198.87054856618246 Epoch 22/30, Training Loss: 48.645048183865015, Validation Loss: 198.45662053426108 Epoch 23/30, Training Loss: 48.5343746609158, Validation Loss: 194.56494808197021 Epoch 24/30, Training Loss: 47.689731174045136, Validation Loss: 197.5397570927938 Epoch 25/30, Training Loss: 48.05319391886393, Validation Loss: 197.968976020813 Epoch 26/30, Training Loss: 47.807348039415146, Validation Loss: 198.26151434580484 Epoch 27/30, Training Loss: 47.789947001139325, Validation Loss: 195.22731018066406 Epoch 28/30, Training Loss: 47.258505249023436, Validation Loss: 198.5635643005371 Epoch 29/30, Training Loss: 47.24866765340169, Validation Loss: 197.3188044230143 Epoch 30/30, Training Loss: 47.136260477701825, Validation Loss: 197.86065769195557 Training complete
Training for hidden_size=100, blocks = 10 Epoch 1/30, Training Loss: 9436.184055074056, Validation Loss: 263.33286539713544 Epoch 2/30, Training Loss: 67.61033986409505, Validation Loss: 228.6816488901774 Epoch 3/30, Training Loss: 60.80630026923286, Validation Loss: 215.70353666941324 Epoch 4/30, Training Loss: 57.06355929904514, Validation Loss: 208.36373488108316 Epoch 5/30, Training Loss: 54.60490230984158, Validation Loss: 203.7723045349121 Epoch 6/30, Training Loss: 52.81721649169922, Validation Loss: 198.82628377278647 Epoch 7/30, Training Loss: 51.284493425157336, Validation Loss: 196.23615233103433 Epoch 8/30, Training Loss: 50.3772331237793, Validation Loss: 195.4634517033895 Epoch 9/30, Training Loss: 49.036874050564236, Validation Loss: 190.6725641886393 Epoch 10/30, Training Loss: 47.81766933865018, Validation Loss: 191.61908117930093 Epoch 11/30, Training Loss: 47.524300469292534, Validation Loss: 191.0384251276652 Epoch 12/30, Training Loss: 46.822449662950305, Validation Loss: 191.2440538406372 Epoch 13/30, Training Loss: 45.9994388156467, Validation Loss: 190.5266834894816 Epoch 14/30, Training Loss: 45.582711622450084, Validation Loss: 188.55037212371826 Epoch 15/30, Training Loss: 44.89894324408637, Validation Loss: 187.62296676635742 Epoch 16/30, Training Loss: 44.38909420437283, Validation Loss: 190.2049217224121 Epoch 17/30, Training Loss: 44.68807390001085, Validation Loss: 189.20528157552084 Epoch 18/30, Training Loss: 43.929205830891924, Validation Loss: 189.91830094655356 Epoch 19/30, Training Loss: 43.186015065511064, Validation Loss: 188.0519240697225 Epoch 20/30, Training Loss: 42.362147352430554, Validation Loss: 186.8313970565796 Epoch 21/30, Training Loss: 42.51196111043294, Validation Loss: 186.515017191569 Epoch 22/30, Training Loss: 41.920947265625, Validation Loss: 189.6102253595988 Epoch 23/30, Training Loss: 41.618788146972655, Validation Loss: 187.03820164998373 Epoch 24/30, Training Loss: 41.889136505126956, Validation Loss: 183.9749552408854 Epoch 25/30, Training Loss: 41.11484959920247, Validation Loss: 187.04142888387045 Epoch 26/30, Training Loss: 41.106873491075305, Validation Loss: 185.5558303197225 Epoch 27/30, Training Loss: 40.58207711113824, Validation Loss: 190.92296314239502 Epoch 28/30, Training Loss: 40.81068674723307, Validation Loss: 193.93083826700845 Epoch 29/30, Training Loss: 40.431168450249565, Validation Loss: 188.75267124176025 Epoch 30/30, Training Loss: 39.85402603149414, Validation Loss: 184.7258857091268 Training complete
Training for hidden_size=200, blocks = 2 Epoch 1/30, Training Loss: 216.53725297715928, Validation Loss: 444.19052505493164 Epoch 2/30, Training Loss: 113.37141401502821, Validation Loss: 364.4847469329834 Epoch 3/30, Training Loss: 100.15595075819228, Validation Loss: 345.93261528015137 Epoch 4/30, Training Loss: 94.53593071831597, Validation Loss: 331.55730120340985 Epoch 5/30, Training Loss: 92.21392093234591, Validation Loss: 324.0475031534831 Epoch 6/30, Training Loss: 88.54855431450738, Validation Loss: 308.84101994832355 Epoch 7/30, Training Loss: 86.4598612467448, Validation Loss: 309.0832977294922 Epoch 8/30, Training Loss: 84.3639392428928, Validation Loss: 306.3279215494792 Epoch 9/30, Training Loss: 83.28580000135634, Validation Loss: 311.41941324869794 Epoch 10/30, Training Loss: 81.50191514756945, Validation Loss: 296.9877471923828 Epoch 11/30, Training Loss: 79.26307542588975, Validation Loss: 278.895850499471 Epoch 12/30, Training Loss: 78.94070926242405, Validation Loss: 271.9144166310628 Epoch 13/30, Training Loss: 78.2772957695855, Validation Loss: 293.2340513865153 Epoch 14/30, Training Loss: 78.42385457356771, Validation Loss: 287.2551898956299 Epoch 15/30, Training Loss: 76.45130123562284, Validation Loss: 274.488676071167 Epoch 16/30, Training Loss: 76.47469346788195, Validation Loss: 276.59137535095215 Epoch 17/30, Training Loss: 75.16443176269532, Validation Loss: 275.96004931132 Epoch 18/30, Training Loss: 76.47992943657769, Validation Loss: 270.53071784973145 Epoch 19/30, Training Loss: 74.54247843424479, Validation Loss: 273.3419183095296 Epoch 20/30, Training Loss: 74.45988430447049, Validation Loss: 276.253563563029 Epoch 21/30, Training Loss: 74.91201510959202, Validation Loss: 280.04852358500165 Epoch 22/30, Training Loss: 74.06419660780165, Validation Loss: 261.0175202687581 Epoch 23/30, Training Loss: 73.2760986328125, Validation Loss: 266.8240426381429 Epoch 24/30, Training Loss: 73.63587934705946, Validation Loss: 268.8825174967448 Epoch 25/30, Training Loss: 72.39171634250216, Validation Loss: 256.8872324625651 Epoch 26/30, Training Loss: 71.26632826063368, Validation Loss: 260.2430502573649 Epoch 27/30, Training Loss: 72.33309190538195, Validation Loss: 263.19954744974774 Epoch 28/30, Training Loss: 71.85370229085287, Validation Loss: 264.2111365000407 Epoch 29/30, Training Loss: 70.93008236355251, Validation Loss: 263.0232054392497 Epoch 30/30, Training Loss: 71.94356502956815, Validation Loss: 258.7958990732829 Training complete
Training for hidden_size=200, blocks = 5 Epoch 1/30, Training Loss: 309.551146613227, Validation Loss: 296.65241622924805 Epoch 2/30, Training Loss: 77.70625271267362, Validation Loss: 259.46020062764484 Epoch 3/30, Training Loss: 69.53970616658529, Validation Loss: 240.80973148345947 Epoch 4/30, Training Loss: 65.07096065945096, Validation Loss: 229.57966740926108 Epoch 5/30, Training Loss: 62.08271713256836, Validation Loss: 223.40517075856528 Epoch 6/30, Training Loss: 59.97387279934353, Validation Loss: 221.47677834828696 Epoch 7/30, Training Loss: 58.47309239705404, Validation Loss: 217.21256033579508 Epoch 8/30, Training Loss: 56.85402018229167, Validation Loss: 214.76035340627035 Epoch 9/30, Training Loss: 55.775561014811196, Validation Loss: 210.1625213623047 Epoch 10/30, Training Loss: 54.69895239935981, Validation Loss: 212.7543576558431 Epoch 11/30, Training Loss: 54.13418511284722, Validation Loss: 206.8346061706543 Epoch 12/30, Training Loss: 53.188574133978946, Validation Loss: 206.58446153004965 Epoch 13/30, Training Loss: 52.80375417073568, Validation Loss: 209.3851442337036 Epoch 14/30, Training Loss: 52.33621393839518, Validation Loss: 202.7700351079305 Epoch 15/30, Training Loss: 51.24921044243707, Validation Loss: 200.8417387008667 Epoch 16/30, Training Loss: 50.287712012396916, Validation Loss: 201.89845403035483 Epoch 17/30, Training Loss: 49.85312194824219, Validation Loss: 202.5632349650065 Epoch 18/30, Training Loss: 49.33689439561632, Validation Loss: 199.69777806599936 Epoch 19/30, Training Loss: 49.72008887396918, Validation Loss: 204.82231680552164 Epoch 20/30, Training Loss: 49.121662309434676, Validation Loss: 202.7250280380249 Epoch 21/30, Training Loss: 49.025482940673825, Validation Loss: 206.9731995264689 Epoch 22/30, Training Loss: 48.66969256930881, Validation Loss: 202.8100382486979 Epoch 23/30, Training Loss: 47.902609168158634, Validation Loss: 202.15858459472656 Epoch 24/30, Training Loss: 47.533621554904514, Validation Loss: 199.37864557902017 Epoch 25/30, Training Loss: 46.99391530354818, Validation Loss: 201.3996795018514 Epoch 26/30, Training Loss: 47.13845409817166, Validation Loss: 200.77649021148682 Epoch 27/30, Training Loss: 47.37102983262804, Validation Loss: 199.04171403249106 Epoch 28/30, Training Loss: 46.85574535793728, Validation Loss: 204.1581137975057 Epoch 29/30, Training Loss: 46.16206520928277, Validation Loss: 201.75563176472983 Epoch 30/30, Training Loss: 46.21358116997613, Validation Loss: 202.40555699666342 Training complete
Training for hidden_size=200, blocks = 10 Epoch 1/30, Training Loss: 6492.414649454752, Validation Loss: 282.70446332295734 Epoch 2/30, Training Loss: 73.40664850870768, Validation Loss: 246.55630683898926 Epoch 3/30, Training Loss: 64.3135969373915, Validation Loss: 226.65690644582114 Epoch 4/30, Training Loss: 60.09317186143663, Validation Loss: 215.6777229309082 Epoch 5/30, Training Loss: 57.14421852959527, Validation Loss: 210.23212718963623 Epoch 6/30, Training Loss: 55.430495198567705, Validation Loss: 204.4380890528361 Epoch 7/30, Training Loss: 53.62228147718641, Validation Loss: 208.22985331217447 Epoch 8/30, Training Loss: 52.40490570068359, Validation Loss: 201.35067780812582 Epoch 9/30, Training Loss: 51.76244778103299, Validation Loss: 199.5050137837728 Epoch 10/30, Training Loss: 50.03342141045464, Validation Loss: 195.68669923146567 Epoch 11/30, Training Loss: 49.121512010362416, Validation Loss: 197.29208914438883 Epoch 12/30, Training Loss: 48.897773827446834, Validation Loss: 196.3861207962036 Epoch 13/30, Training Loss: 48.12302941216363, Validation Loss: 196.06686433156332 Epoch 14/30, Training Loss: 47.333321211073134, Validation Loss: 191.7859255472819 Epoch 15/30, Training Loss: 47.04701470269097, Validation Loss: 194.23813724517822 Epoch 16/30, Training Loss: 46.205683898925784, Validation Loss: 193.08368174235025 Epoch 17/30, Training Loss: 45.86734517415365, Validation Loss: 194.6445223490397 Epoch 18/30, Training Loss: 45.691319190131296, Validation Loss: 190.83626715342203 Epoch 19/30, Training Loss: 44.578766123453775, Validation Loss: 198.11820379892984 Epoch 20/30, Training Loss: 44.48360688951281, Validation Loss: 196.22168699900308 Epoch 21/30, Training Loss: 44.72951049804688, Validation Loss: 192.11080837249756 Epoch 22/30, Training Loss: 44.1974353366428, Validation Loss: 193.88198947906494 Epoch 23/30, Training Loss: 44.18268771701389, Validation Loss: 196.84188238779703 Epoch 24/30, Training Loss: 43.73227649264865, Validation Loss: 193.48014958699545 Epoch 25/30, Training Loss: 43.711177571614584, Validation Loss: 194.7969299952189 Epoch 26/30, Training Loss: 42.93069229125977, Validation Loss: 191.0793244043986 Epoch 27/30, Training Loss: 42.515458255343965, Validation Loss: 191.73682816823325 Epoch 28/30, Training Loss: 43.338741048177084, Validation Loss: 193.38479391733804 Epoch 29/30, Training Loss: 42.78940794203017, Validation Loss: 194.90061601003012 Epoch 30/30, Training Loss: 42.59330757988824, Validation Loss: 195.17180665334067 Training complete
Training for hidden_size=400, blocks = 2 Epoch 1/30, Training Loss: 307.47569749620226, Validation Loss: 660.282797495524 Epoch 2/30, Training Loss: 180.9752634684245, Validation Loss: 812.9520988464355 Epoch 3/30, Training Loss: 173.47369859483507, Validation Loss: 569.1646308898926 Epoch 4/30, Training Loss: 159.17198621961805, Validation Loss: 565.31818262736 Epoch 5/30, Training Loss: 158.90091518825955, Validation Loss: 551.1031494140625 Epoch 6/30, Training Loss: 159.60764600965712, Validation Loss: 574.1276728312174 Epoch 7/30, Training Loss: 145.9571780734592, Validation Loss: 495.2603123982747 Epoch 8/30, Training Loss: 142.371486070421, Validation Loss: 491.8583634694417 Epoch 9/30, Training Loss: 141.50491773817274, Validation Loss: 498.96835072835285 Epoch 10/30, Training Loss: 143.82400377061632, Validation Loss: 492.4005317687988 Epoch 11/30, Training Loss: 146.19541456434462, Validation Loss: 574.3242225646973 Epoch 12/30, Training Loss: 145.28963758680555, Validation Loss: 528.7288983662924 Epoch 13/30, Training Loss: 136.63750644259983, Validation Loss: 515.2850685119629 Epoch 14/30, Training Loss: 142.1431399875217, Validation Loss: 489.8877480824788 Epoch 15/30, Training Loss: 142.1147196451823, Validation Loss: 474.4272352854411 Epoch 16/30, Training Loss: 140.7830559624566, Validation Loss: 495.7752113342285 Epoch 17/30, Training Loss: 147.55852644178603, Validation Loss: 470.71923510233563 Epoch 18/30, Training Loss: 141.1549747043186, Validation Loss: 539.7263113657633 Epoch 19/30, Training Loss: 135.7890879313151, Validation Loss: 485.22253545125324 Epoch 20/30, Training Loss: 135.64784579806857, Validation Loss: 490.24193954467773 Epoch 21/30, Training Loss: 130.6491214328342, Validation Loss: 452.17605781555176 Epoch 22/30, Training Loss: 126.53772633870443, Validation Loss: 451.6606750488281 Epoch 23/30, Training Loss: 127.67455647786458, Validation Loss: 449.65537707010907 Epoch 24/30, Training Loss: 127.42182057698568, Validation Loss: 488.5056915283203 Epoch 25/30, Training Loss: 127.36804436577691, Validation Loss: 496.6194330851237 Epoch 26/30, Training Loss: 133.38234303792316, Validation Loss: 489.46679941813153 Epoch 27/30, Training Loss: 130.7207258436415, Validation Loss: 513.3341700236002 Epoch 28/30, Training Loss: 127.46561448838976, Validation Loss: 457.0716412862142 Epoch 29/30, Training Loss: 125.12691226535374, Validation Loss: 442.3211104075114 Epoch 30/30, Training Loss: 124.04728088378906, Validation Loss: 441.13921610514325 Training complete
Training for hidden_size=400, blocks = 5 Epoch 1/30, Training Loss: 316.47518547905815, Validation Loss: 309.94998868306476 Epoch 2/30, Training Loss: 82.13634084065755, Validation Loss: 275.5159142812093 Epoch 3/30, Training Loss: 74.84136166042752, Validation Loss: 262.6661961873372 Epoch 4/30, Training Loss: 70.7937240600586, Validation Loss: 249.78180122375488 Epoch 5/30, Training Loss: 68.1334605746799, Validation Loss: 245.63199742635092 Epoch 6/30, Training Loss: 65.83484369913737, Validation Loss: 234.56178633371988 Epoch 7/30, Training Loss: 63.888429175482855, Validation Loss: 240.55420271555582 Epoch 8/30, Training Loss: 62.737112765842014, Validation Loss: 235.59585825602213 Epoch 9/30, Training Loss: 62.29926893446181, Validation Loss: 233.10778681437174 Epoch 10/30, Training Loss: 61.70500666300456, Validation Loss: 228.81559340159097 Epoch 11/30, Training Loss: 60.07350964016385, Validation Loss: 227.86209615071616 Epoch 12/30, Training Loss: 59.456063503689236, Validation Loss: 226.61377970377603 Epoch 13/30, Training Loss: 58.891603766547306, Validation Loss: 227.92042700449625 Epoch 14/30, Training Loss: 58.36316468980577, Validation Loss: 224.34271812438965 Epoch 15/30, Training Loss: 57.4016359117296, Validation Loss: 224.10171476999918 Epoch 16/30, Training Loss: 58.129408518473305, Validation Loss: 221.11990706125894 Epoch 17/30, Training Loss: 57.24721306694879, Validation Loss: 227.1364345550537 Epoch 18/30, Training Loss: 56.437878672281904, Validation Loss: 224.80929819742838 Epoch 19/30, Training Loss: 56.5604724460178, Validation Loss: 220.81223583221436 Epoch 20/30, Training Loss: 54.89689534505208, Validation Loss: 220.16179784138998 Epoch 21/30, Training Loss: 54.49175864325629, Validation Loss: 223.90884653727213 Epoch 22/30, Training Loss: 54.91522538926866, Validation Loss: 221.36671447753906 Epoch 23/30, Training Loss: 54.64884719848633, Validation Loss: 224.3624588648478 Epoch 24/30, Training Loss: 54.53467610677083, Validation Loss: 219.23523680369058 Epoch 25/30, Training Loss: 54.39541456434462, Validation Loss: 225.01481914520264 Epoch 26/30, Training Loss: 53.75576400756836, Validation Loss: 224.30957667032877 Epoch 27/30, Training Loss: 53.983888668484155, Validation Loss: 222.11862341562906 Epoch 28/30, Training Loss: 53.86955354478624, Validation Loss: 225.67633787790933 Epoch 29/30, Training Loss: 53.31535432603624, Validation Loss: 224.4200760523478 Epoch 30/30, Training Loss: 52.715677218967016, Validation Loss: 223.37881247202554 Training complete
Training for hidden_size=400, blocks = 10 Epoch 1/30, Training Loss: 13539.159749179416, Validation Loss: 342.41082700093585 Epoch 2/30, Training Loss: 86.58183034261067, Validation Loss: 278.1364510854085 Epoch 3/30, Training Loss: 75.3851586235894, Validation Loss: 259.4295717875163 Epoch 4/30, Training Loss: 70.06561041937934, Validation Loss: 287.4711570739746 Epoch 5/30, Training Loss: 67.08853412204319, Validation Loss: 243.05325508117676 Epoch 6/30, Training Loss: 64.71018481784397, Validation Loss: 239.3382879892985 Epoch 7/30, Training Loss: 63.28148990207248, Validation Loss: 233.9927625656128 Epoch 8/30, Training Loss: 61.262958187527126, Validation Loss: 225.98743216196695 Epoch 9/30, Training Loss: 60.458206515842015, Validation Loss: 230.18698183695474 Epoch 10/30, Training Loss: 58.9208381652832, Validation Loss: 224.2623545328776 Epoch 11/30, Training Loss: 58.24448852539062, Validation Loss: 221.45679410298666 Epoch 12/30, Training Loss: 56.951568094889325, Validation Loss: 219.85390186309814 Epoch 13/30, Training Loss: 56.31669930352105, Validation Loss: 219.88257439931235 Epoch 14/30, Training Loss: 55.53523585001628, Validation Loss: 222.5670598347982 Epoch 15/30, Training Loss: 54.74429762098524, Validation Loss: 213.9870694478353 Epoch 16/30, Training Loss: 54.23894585503472, Validation Loss: 224.2745631535848 Epoch 17/30, Training Loss: 53.661950005425346, Validation Loss: 217.9786860148112 Epoch 18/30, Training Loss: 54.17339884440104, Validation Loss: 217.71878623962402 Epoch 19/30, Training Loss: 53.162393188476564, Validation Loss: 213.33333587646484 Epoch 20/30, Training Loss: 52.72771631876628, Validation Loss: 216.45318190256754 Epoch 21/30, Training Loss: 53.03551466200087, Validation Loss: 219.1963389714559 Epoch 22/30, Training Loss: 53.08130594889323, Validation Loss: 221.36446475982666 Epoch 23/30, Training Loss: 51.74025141398112, Validation Loss: 217.7514041264852 Epoch 24/30, Training Loss: 51.72227588229709, Validation Loss: 220.6808811823527 Epoch 25/30, Training Loss: 51.49509489271376, Validation Loss: 213.47276719411215 Epoch 26/30, Training Loss: 51.085971408420136, Validation Loss: 215.52674261728922 Epoch 27/30, Training Loss: 50.79202465481228, Validation Loss: 212.3214489618937 Epoch 28/30, Training Loss: 50.732866838243275, Validation Loss: 212.61949157714844 Epoch 29/30, Training Loss: 50.685603586832684, Validation Loss: 212.96100012461343 Epoch 30/30, Training Loss: 50.355876753065324, Validation Loss: 218.3713518778483 Training complete
We see that the best network structure appears to be a hidden size of 100 with 10 blocks.
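As a sanity check on this choice, the minimum validation losses read off the grid-search logs above can be tabulated (a small sketch; the numbers are copied from the printed logs, not recomputed).
# Minimum validation loss per (hidden_size, blocks), read off the logs above
results = {
    (100, 2): 307.85, (100, 5): 194.56, (100, 10): 183.97,
    (200, 2): 256.89, (200, 5): 199.04, (200, 10): 190.84,
    (400, 2): 441.14, (400, 5): 219.24, (400, 10): 212.32,
}
best = min(results, key=results.get)
print(f"Best (hidden_size, blocks): {best} with validation loss {results[best]:.2f}")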
class RealNVP_bottleneck(nn.Module):
    def __init__(self, input_size, hidden_size, blocks, k):
        super(RealNVP_bottleneck, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.blocks = blocks
        self.k = k
        # List of coupling layers
        self.coupling_layers = nn.ModuleList([
            CouplingLayer(input_size, hidden_size) for _ in range(blocks)
        ])
        # List to store orthonormal matrices
        self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]
        # List to store scaling_before_exp for each block
        self.scaling_before_exp_list = []

    def _get_orthonormal_matrix(self, size):
        # Function to generate a random orthonormal matrix
        w = torch.randn(size, size)
        q, _ = torch.linalg.qr(w, 'reduced')
        return q

    def forward_realnvp(self, x):
        scaling_before_exp_list = []
        for i in range(self.blocks):
            # Apply random orthonormal matrix
            x = torch.matmul(x, self.orthonormal_matrices[i])
            # Apply coupling layer
            x, scaling_before_exp = self.coupling_layers[i].forward(x)
            scaling_before_exp_list.append(scaling_before_exp)
        self.scaling_before_exp_list = scaling_before_exp_list
        return x

    def encode(self, x):
        # Encoding is the forward pass through the RealNVP model
        return self.forward_realnvp(x)

    def decode(self, z):
        # Modify z to zero out dimensions beyond k for the reconstruction
        z_reconstructed = z.clone()
        if self.k < self.input_size:
            z_reconstructed[:, self.k:] = 0  # Zero out dimensions beyond k
        # Proceed with the original decoding process
        for i in reversed(range(self.blocks)):
            z = self.coupling_layers[i].backward(z)
            z_reconstructed = self.coupling_layers[i].backward(z_reconstructed)
            z = torch.matmul(z, self.orthonormal_matrices[i].t())
            z_reconstructed = torch.matmul(z_reconstructed, self.orthonormal_matrices[i].t())
        return z, z_reconstructed

    def sample(self, num_samples=1000):
        # Generate random samples from a standard normal distribution
        with torch.no_grad():
            z = torch.randn(num_samples, self.input_size)
            # Apply the reverse transformations (decoder) to generate synthetic samples
            _, synthetic_samples = self.decode(z)
            return synthetic_samples

    def sample_only_important(self, num_samples=1000):
        # Generate random samples from a standard normal distribution
        with torch.no_grad():
            z_1 = torch.randn(num_samples, self.k)
            z_2 = torch.zeros(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)
            # Apply the reverse transformations (decoder) to generate synthetic samples
            _, synthetic_samples = self.decode(z)
            return synthetic_samples

    def sample_only_unimportant(self, num_samples=1000):
        # Generate random samples from a standard normal distribution
        with torch.no_grad():
            z_1 = torch.randn(1, self.k).repeat(num_samples, 1)
            z_2 = torch.randn(num_samples, self.input_size - self.k)
            z = torch.cat((z_1, z_2), dim=1)
            # Apply the reverse transformations (decoder) to generate synthetic samples
            _, synthetic_samples = self.decode(z)
            return synthetic_samples
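Before training the bottleneck model, a quick sanity check is useful (a sketch, assuming the CouplingLayer class from the top of the notebook is in scope): the unrestricted decode path should invert the encoder up to floating-point error, while the restricted path reconstructs from only the first k latent dimensions.
import torch
model = RealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10, k=4)
x = torch.randn(8, 64)
with torch.no_grad():
    z = model.encode(x)
    x_full, x_k = model.decode(z)
print(torch.allclose(x, x_full, atol=1e-3))  # should print True: full path inverts the encoder
print(x_k.shape)                             # torch.Size([8, 64]): reconstruction from k dims only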
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1):
    """
    Train the RealNVP model and evaluate on a validation dataset.
    Args:
    - model (RealNVP): The RealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.
    Returns:
    - train_losses_nll (list): Training NLL losses for each epoch.
    - train_losses_recons (list): Training reconstruction losses for each epoch.
    - val_losses_nll (list): Validation NLL losses for each epoch.
    - val_losses_recons (list): Validation reconstruction losses for each epoch.
    """
    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    train_losses_nll = []
    val_losses_nll = []
    train_losses_recons = []
    val_losses_recons = []
    # Training phase
    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0
        for data in train_loader:
            inputs = data
            # Zero the gradients
            optimizer.zero_grad()
            # NLL loss calculation
            encoded = model.encode(inputs)
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            # Reconstruction loss calculation
            _, decoded = model.decode(encoded)
            train_loss_recons = mse_loss(inputs, decoded)
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()
            # Clip the gradients to stabilize training
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed
            # Update weights
            optimizer.step()
            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()
        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)
        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs = val_data
                    # NLL loss calculation
                    encoded = model.encode(val_inputs)
                    val_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(val_loader))
                    # Reconstruction loss calculation
                    _, decoded = model.decode(encoded)
                    val_loss_recons = mse_loss(val_inputs, decoded)
                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()
            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)
        # Print training and validation losses together
        if (epoch + 1) % print_after == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons + average_train_loss_nll}, Validation Loss: {average_val_loss_nll + average_val_loss_recons}")
        # Append losses to the lists
        train_losses_nll.append(average_train_loss_nll)
        val_losses_nll.append(average_val_loss_nll)
        train_losses_recons.append(average_train_loss_recons)
        val_losses_recons.append(average_val_loss_recons)
        # Set the model back to training mode
        model.train()
    print("Training complete")
    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
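The calculate_loss function used above is defined earlier in the notebook and not reproduced here. For orientation, a plausible sketch of the RealNVP negative log-likelihood under a standard normal prior is given below (the actual function also takes a normalization argument); note that the orthonormal rotations have unit determinant, so only the coupling scales contribute to the log-determinant.
import math
def nll_sketch(z, scaling_before_exp_list):
    # log N(z; 0, I), summed over dimensions
    log_prior = -0.5 * (z ** 2).sum(dim=1) - 0.5 * z.shape[1] * math.log(2 * math.pi)
    # log|det J| of the flow: the pre-exp scaling outputs, summed over dims and blocks
    log_det = sum(s.sum(dim=1) for s in scaling_before_exp_list)
    return -(log_prior + log_det).mean()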
k_values = [2, 4, 8]
dataset_percentage = 1.0
# Create data loader for the fixed dataset size
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)
for k in k_values:
    print(f"\nTraining for k={k}")
    # Instantiate the model
    model = RealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10, k=k)
    # Train the model
    train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
    train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
    val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
    # plotting the loss
    plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
    plt.show()
    # Example usage:
    plot_code_distribution(model=model, test_loader=val_loader)
    plt.show()
    ### plot the synthetic data
    synthetic_data = model.sample(num_samples=len(data_considered))
    visualize_synthetic_data(synthetic_data)
Training for k=2 Epoch 1/20, Training Loss: 1134181.2054025438, Validation Loss: 393.4980107943217 Epoch 2/20, Training Loss: 104.72341115739611, Validation Loss: 277.5674916903178 Epoch 3/20, Training Loss: 85.00006866455078, Validation Loss: 252.0516587893168 Epoch 4/20, Training Loss: 77.07950909932454, Validation Loss: 236.80850172042847 Epoch 5/20, Training Loss: 72.53351398044163, Validation Loss: 228.5046550432841 Epoch 6/20, Training Loss: 69.57301697201198, Validation Loss: 222.44894289970398 Epoch 7/20, Training Loss: 67.61917934417724, Validation Loss: 218.56156961123148 Epoch 8/20, Training Loss: 65.83680402967664, Validation Loss: 215.66464479764304 Epoch 9/20, Training Loss: 64.46760686238606, Validation Loss: 212.84963099161783 Epoch 10/20, Training Loss: 63.096194140116374, Validation Loss: 213.09209056695303 Epoch 11/20, Training Loss: 62.07158679962158, Validation Loss: 207.5499466260274 Epoch 12/20, Training Loss: 61.22519813113742, Validation Loss: 212.94606955846152 Epoch 13/20, Training Loss: 60.01767086452908, Validation Loss: 206.5381192366282 Epoch 14/20, Training Loss: 59.57084166208903, Validation Loss: 203.71867847442627 Epoch 15/20, Training Loss: 58.656814617580835, Validation Loss: 205.08861184120178 Epoch 16/20, Training Loss: 58.426832644144696, Validation Loss: 205.47613739967346 Epoch 17/20, Training Loss: 57.60618413289388, Validation Loss: 204.82380390167236 Epoch 18/20, Training Loss: 56.88742605845133, Validation Loss: 205.32237219810486 Epoch 19/20, Training Loss: 56.780501280890576, Validation Loss: 206.95415838559467 Epoch 20/20, Training Loss: 56.04133735232883, Validation Loss: 202.103222211202 Training complete
Training for k=4 Epoch 1/20, Training Loss: 1764351.6226888022, Validation Loss: 6121.43034807841 Epoch 2/20, Training Loss: 305.77129427591956, Validation Loss: 328.2409567832947 Epoch 3/20, Training Loss: 93.7455064561632, Validation Loss: 274.5516738096873 Epoch 4/20, Training Loss: 80.29229000939263, Validation Loss: 248.36701107025146 Epoch 5/20, Training Loss: 74.14117916954888, Validation Loss: 236.20576540629068 Epoch 6/20, Training Loss: 69.56096011267768, Validation Loss: 228.18267114957172 Epoch 7/20, Training Loss: 66.54357260598077, Validation Loss: 219.3884553114573 Epoch 8/20, Training Loss: 64.57242608600193, Validation Loss: 215.73117335637411 Epoch 9/20, Training Loss: 62.44744293424819, Validation Loss: 211.98541418711343 Epoch 10/20, Training Loss: 60.773972034454346, Validation Loss: 208.40398891766867 Epoch 11/20, Training Loss: 59.45615416632758, Validation Loss: 205.69273841381073 Epoch 12/20, Training Loss: 58.051778125762944, Validation Loss: 206.46651673316956 Epoch 13/20, Training Loss: 57.31363100475735, Validation Loss: 202.31709150473276 Epoch 14/20, Training Loss: 56.459687858157686, Validation Loss: 203.77659797668457 Epoch 15/20, Training Loss: 55.96581435733371, Validation Loss: 201.2215979496638 Epoch 16/20, Training Loss: 55.15948048697578, Validation Loss: 199.77739560604095 Epoch 17/20, Training Loss: 55.31225707795885, Validation Loss: 200.20230305194855 Epoch 18/20, Training Loss: 54.43091343773736, Validation Loss: 199.18999723593393 Epoch 19/20, Training Loss: 53.94967920515273, Validation Loss: 193.5152560075124 Epoch 20/20, Training Loss: 53.28515498903063, Validation Loss: 197.675413052241 Training complete
Training for k=8 Epoch 1/20, Training Loss: 3957463.0584517587, Validation Loss: 678.2395188013713 Epoch 2/20, Training Loss: 127.46482734680175, Validation Loss: 297.8124193350474 Epoch 3/20, Training Loss: 85.34479524824354, Validation Loss: 255.7841518719991 Epoch 4/20, Training Loss: 75.56492190890842, Validation Loss: 241.27586038907367 Epoch 5/20, Training Loss: 70.02874931759305, Validation Loss: 227.84788632392883 Epoch 6/20, Training Loss: 66.06599095662435, Validation Loss: 219.99094394842783 Epoch 7/20, Training Loss: 63.42344832950168, Validation Loss: 215.41912790139514 Epoch 8/20, Training Loss: 61.34067319234212, Validation Loss: 213.2213078737259 Epoch 9/20, Training Loss: 59.875758753882515, Validation Loss: 210.1514040629069 Epoch 10/20, Training Loss: 58.376855532328285, Validation Loss: 208.3204576174418 Epoch 11/20, Training Loss: 56.99812969631619, Validation Loss: 199.2074755827586 Epoch 12/20, Training Loss: 55.90390529632568, Validation Loss: 201.8850393295288 Epoch 13/20, Training Loss: 55.46685161590577, Validation Loss: 201.70479098955792 Epoch 14/20, Training Loss: 54.837560865614144, Validation Loss: 200.57280039787292 Epoch 15/20, Training Loss: 54.00313622156779, Validation Loss: 201.01371534665427 Epoch 16/20, Training Loss: 53.054822762807206, Validation Loss: 200.3045927286148 Epoch 17/20, Training Loss: 52.68340014351739, Validation Loss: 200.4621553023656 Epoch 18/20, Training Loss: 52.258125665452745, Validation Loss: 198.67079102993011 Epoch 19/20, Training Loss: 51.502514394124354, Validation Loss: 199.64668464660645 Epoch 20/20, Training Loss: 51.13463984595405, Validation Loss: 200.92397185166678 Training complete
We see that the reconstruction already performs far better with the bottleneck, and we get the best results with k = 4. So let's try the two different sampling techniques.
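Concretely, the two samplers split the latent vector into an "important" part z[:, :k] and an "unimportant" part z[:, k:]; schematically (with k = 4, d = 64):
import torch
k, d, n = 4, 64, 5
# sample_only_important: vary the first k dims, zero out the rest
z_imp = torch.cat([torch.randn(n, k), torch.zeros(n, d - k)], dim=1)
# sample_only_unimportant: fix the first k dims across samples, vary the rest
z_unimp = torch.cat([torch.randn(1, k).repeat(n, 1), torch.randn(n, d - k)], dim=1)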
input_size = 64
hidden_size = 100
blocks = 10
print_after = 1
dataset_percentage = 1.0
batch_size = 32
data_considered = train_datasets[dataset_percentage]['X']
train_loader = torch.utils.data.DataLoader(data_considered, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset['X'], batch_size=32, shuffle=True)
# Instantiate the model
model = RealNVP_bottleneck(input_size=input_size, hidden_size=hidden_size, blocks=blocks, k=4)
# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
### plot the synthetic data
synthetic_data = model.sample_only_important(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data, title="Sampling only important features")
synthetic_data = model.sample_only_unimportant(num_samples=len(data_considered))
visualize_synthetic_data(synthetic_data, title="Sampling only unimportant features")
Epoch 1/20, Training Loss: 506830.6718499078, Validation Loss: 428.4326847394308 Epoch 2/20, Training Loss: 107.78536251915827, Validation Loss: 291.4867718219757 Epoch 3/20, Training Loss: 85.82897991604275, Validation Loss: 259.36583797136944 Epoch 4/20, Training Loss: 77.02423313988581, Validation Loss: 241.9788273970286 Epoch 5/20, Training Loss: 71.95412707858615, Validation Loss: 232.03203002611795 Epoch 6/20, Training Loss: 68.29618448681302, Validation Loss: 227.44942140579224 Epoch 7/20, Training Loss: 65.99424372778998, Validation Loss: 224.72661836942038 Epoch 8/20, Training Loss: 63.48433694839478, Validation Loss: 218.69628206888834 Epoch 9/20, Training Loss: 61.90843456056383, Validation Loss: 216.30499251683554 Epoch 10/20, Training Loss: 60.55102155473497, Validation Loss: 214.02308400472006 Epoch 11/20, Training Loss: 59.381386015150284, Validation Loss: 207.5036437511444 Epoch 12/20, Training Loss: 58.41290695402357, Validation Loss: 206.32842751344046 Epoch 13/20, Training Loss: 57.291498226589624, Validation Loss: 206.07214486598969 Epoch 14/20, Training Loss: 57.13765395482381, Validation Loss: 205.6906545559565 Epoch 15/20, Training Loss: 56.12033676571316, Validation Loss: 209.88612131277722 Epoch 16/20, Training Loss: 55.83825403849284, Validation Loss: 204.41844717661542 Epoch 17/20, Training Loss: 54.830014557308616, Validation Loss: 198.95498553911847 Epoch 18/20, Training Loss: 54.01373918321398, Validation Loss: 203.80519477526346 Epoch 19/20, Training Loss: 53.89666982226902, Validation Loss: 198.23216180006662 Epoch 20/20, Training Loss: 53.23209324942695, Validation Loss: 199.4986126422882 Training complete
The algorithm behaves as expected: when we resample the important features, the digits change completely, whereas resampling only the unimportant features produces only small changes in the plotted images.
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter  # install 'pillow' to get PIL
import matplotlib.pyplot as plt

# Define a functor to downsample images
class DownsampleTransform:
    def __init__(self, target_shape, algorithm=Image.Resampling.LANCZOS):
        self.width, self.height = target_shape
        self.algorithm = algorithm

    def __call__(self, img):
        # Resize to (width+2, height+2) and crop off a one-pixel border,
        # which avoids boundary artifacts from the resampling filter
        img = img.resize((self.width + 2, self.height + 2), self.algorithm)
        img = img.crop((1, 1, self.width + 1, self.height + 1))
        return img

# Concatenate a few transforms
transform = transforms.Compose([
    DownsampleTransform(target_shape=(8, 8)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])

# Download MNIST
mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# Create a DataLoader that serves minibatches of size 100
data_loader = DataLoader(mnist_dataset, batch_size=100, shuffle=True)
mnist_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(mnist_test_dataset, batch_size=100, shuffle=True)

# Visualize the first batch of downsampled MNIST images
def show_first_batch(data_loader):
    for batch in data_loader:
        x, y = batch
        fig = plt.figure(figsize=(10, 10))
        for i, img in enumerate(x):
            ax = fig.add_subplot(10, 10, i + 1)
            ax.imshow(img.reshape(8, 8), cmap='gray')
            ax.axis('off')
        break

show_first_batch(data_loader)
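To see the downsampling transform in isolation, it can be applied to a single image; the sketch below uses a synthetic 28x28 grayscale image so it runs without downloading MNIST.
import numpy as np
from PIL import Image
img = Image.fromarray(np.random.randint(0, 256, (28, 28), dtype=np.uint8), mode='L')
small = DownsampleTransform(target_shape=(8, 8))(img)
print(small.size)  # expected: (8, 8)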
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_

def train_and_evaluate(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1):
    """
    Train the RealNVP model and evaluate on a validation dataset.
    Args:
    - model (RealNVP): The RealNVP model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - lr (float): Learning rate for the optimizer.
    - print_after (int): Number of epochs after which to print the training and validation loss.
    Returns:
    - train_losses_nll (list): Training NLL losses for each epoch.
    - train_losses_recons (list): Training reconstruction losses for each epoch.
    - val_losses_nll (list): Validation NLL losses for each epoch.
    - val_losses_recons (list): Validation reconstruction losses for each epoch.
    """
    # Define the optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    mse_loss = nn.MSELoss()
    train_losses_nll = []
    val_losses_nll = []
    train_losses_recons = []
    val_losses_recons = []
    # Training phase
    model.train()  # Set the model to training mode
    for epoch in range(num_epochs):
        total_train_loss_nll = 0.0
        total_train_loss_recons = 0.0
        for batch in train_loader:
            X, y = batch
            # Flatten the 8x8 images into 64-dimensional vectors
            inputs = X.reshape(len(y), 64)
            # Zero the gradients
            optimizer.zero_grad()
            # NLL loss calculation
            encoded = model.encode(inputs)
            train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
            # Reconstruction loss calculation
            decoded = model.decode(encoded)
            train_loss_recons = mse_loss(inputs, decoded)
            # Backward pass (gradient computation)
            loss = train_loss_nll + train_loss_recons
            loss.backward()
            # Clip the gradients to stabilize training
            clip_grad_norm_(model.parameters(), max_norm=1.0)  # Adjust max_norm as needed
            # Update weights
            optimizer.step()
            total_train_loss_nll += train_loss_nll.item()
            total_train_loss_recons += train_loss_recons.item()
        # Average training loss for the epoch
        average_train_loss_nll = total_train_loss_nll / len(train_loader)
        average_train_loss_recons = total_train_loss_recons / len(train_loader)
        # Validation phase
        model.eval()  # Set the model to evaluation mode
        if val_loader is not None:
            total_val_loss_nll = 0.0
            total_val_loss_recons = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    X, y = batch
                    val_inputs = X.reshape(len(y), 64)
                    # NLL loss calculation
                    encoded = model.encode(val_inputs)
                    val_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(val_loader))
                    # Reconstruction loss calculation
                    decoded = model.decode(encoded)
                    val_loss_recons = mse_loss(val_inputs, decoded)
                    total_val_loss_nll += val_loss_nll.item()
                    total_val_loss_recons += val_loss_recons.item()
            # Average validation loss for the epoch
            average_val_loss_nll = total_val_loss_nll / len(val_loader)
            average_val_loss_recons = total_val_loss_recons / len(val_loader)
        # Print training and validation losses together
        if (epoch + 1) % print_after == 0:
            print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons + average_train_loss_nll}, Validation Loss: {average_val_loss_nll + average_val_loss_recons}")
        # Append losses to the lists
        train_losses_nll.append(average_train_loss_nll)
        val_losses_nll.append(average_val_loss_nll)
        train_losses_recons.append(average_train_loss_recons)
        val_losses_recons.append(average_val_loss_recons)
        # Set the model back to training mode
        model.train()
    print("Training complete")
    return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
def plot_code_distribution(model, test_loader):
    """
    Plot the code distribution obtained by applying the trained RealNVP model to a test dataset.
    Args:
    - model (RealNVP): Trained RealNVP model.
    - test_loader (DataLoader): DataLoader for the test dataset.
    Returns:
    None (displays the plot).
    """
    model.eval()  # Set the model to evaluation mode
    fig, axs = plt.subplots(2, 5, figsize=(20, 7))
    with torch.no_grad():
        # Concatenate multiple batches to obtain more samples
        test_samples = torch.cat([X for (X, y) in test_loader], dim=0)
        test_samples = test_samples.reshape(len(test_samples), 64)
        # Encode the test samples into latent codes
        code_samples = model.encode(test_samples)
        # Convert PyTorch tensor to numpy array
        code_np = code_samples.numpy()
    dim_1 = 0
    dim_2 = 1
    for i in range(2):
        for j in range(5):
            # Scatter plot of the code distribution over two consecutive dimensions
            axs[i, j].scatter(code_np[:, dim_1], code_np[:, dim_2], label='Code Distribution', alpha=0.5)
            axs[i, j].set_xlabel(f"Code Dimension {dim_1}")
            axs[i, j].set_ylabel(f"Code Dimension {dim_2}")
            axs[i, j].set_title(f'Code Distribution: dims {dim_1} vs {dim_2}')
            dim_1 += 1
            dim_2 += 1
    plt.tight_layout()
    plt.show()
input_size = 64
hidden_size = 200
blocks = 10
print_after = 1
dataset_percentage = 1.0
batch_size = 100
# Instantiate the model
model = RealNVP(input_size=input_size, hidden_size=hidden_size, blocks=blocks)
# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_evaluate(model, data_loader, val_loader, num_epochs=10, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
# plotting the loss
plot_losses(train_losses, val_losses, want_log_scale=0)
plt.show()
# Example usage:
plot_code_distribution(model=model, test_loader=val_loader)
plt.show()
### plot the synthetic data
synthetic_data = model.sample(num_samples=100)
visualize_synthetic_data(synthetic_data, title="Synthetic MNIST samples")
Epoch 1/10, Training Loss: -19.166037408212674, Validation Loss: -130.39732536315904 Epoch 2/10, Training Loss: -22.551581888198704, Validation Loss: -138.417531890869 Epoch 3/10, Training Loss: -23.38100530624374, Validation Loss: -142.53491348266584 Epoch 4/10, Training Loss: -23.802392495473065, Validation Loss: -143.01150100707991 Epoch 5/10, Training Loss: -24.102571767171064, Validation Loss: -145.12324249267562 Epoch 6/10, Training Loss: -24.330206867853637, Validation Loss: -146.54541503906233 Epoch 7/10, Training Loss: -24.49192481994612, Validation Loss: -146.0686399841307 Epoch 8/10, Training Loss: -24.635013628005815, Validation Loss: -146.93343902587873 Epoch 9/10, Training Loss: -24.749749383926222, Validation Loss: -148.22393341064435 Epoch 10/10, Training Loss: -24.859632444381543, Validation Loss: -148.3804725646971 Training complete
The results look better than with the digits dataset. Note that the total loss is now negative, which is plausible here: the NLL term of a continuous density is unbounded below, while the MSE term stays non-negative. Still, using a bottleneck remains superior in both performance and training time.
Let's continue Task 3 with a conditional INN.
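The ConditionalRealNVP class and train_and_validate_conditional_nvp are defined earlier in the notebook and not repeated here. As a reminder of the mechanism, the sketch below shows one common way to condition an affine coupling layer: concatenate the condition vector to the untouched half before the scale/translation networks (this mirrors the CouplingLayer above; the actual implementation in this notebook may differ in details).
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConditionalCouplingSketch(nn.Module):
    def __init__(self, input_size, hidden_size, condition_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size // 2 + condition_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, input_size // 2)  # translation
        self.fc4 = nn.Linear(hidden_size, input_size // 2)  # scaling (pre-exp)

    def forward(self, x, condition):
        x_a, x_b = x.chunk(2, dim=1)
        # The condition enters only through the subnetworks, so invertibility in x is preserved
        h = F.relu(self.fc1(torch.cat([x_a, condition], dim=1)))
        h = F.relu(self.fc2(h))
        scaling_before_exp = torch.tanh(self.fc4(h))
        y_b = x_b * torch.exp(scaling_before_exp) + self.fc3(h)
        return torch.cat([x_a, y_b], dim=1), scaling_before_exp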
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import torch
# Load the digits dataset
digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2)
#### data for the two-moons model
from torch.utils.data import TensorDataset, DataLoader
# Define a custom dataset
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.labels[index]
return x, y
# Define model parameters
input_size = 64
hidden_size = 100
condition_size = 10
blocks = 10
percentage = 1.0
num_epochs = 10
lr = 0.005
y_train = torch.arange(condition_size)[y_train].long()
y_test = torch.arange(condition_size)[y_test].long()
# Initialize the model
conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, blocks)
train_dataset = CustomDataset(torch.FloatTensor(X_train), y_train)
val_dataset = CustomDataset(torch.FloatTensor(X_test), y_test)
# Define batch size
batch_size = 32
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Task 1: Train the Conditional INN
train_loss, val_loss = train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
num_epochs=num_epochs, lr=lr, print_after=1)
# plotting the loss
plot_losses(train_loss[1:], val_loss[1:], want_log_scale=0)
plt.show()
conditions_all_labels = torch.eye(condition_size)
synthetic_data = conditional_inn_model.sample(num_samples=10, conditions=conditions_all_labels)
visualize_synthetic_data(synthetic_data,title="Synthetic digits from 0 to 9")
plt.show()
Epoch 1/10, Training Loss: 14079.221368747287, Validation Loss: 251.96936988830566
Epoch 2/10, Training Loss: 66.6776138305664, Validation Loss: 224.05490080515543
Epoch 3/10, Training Loss: 59.361120012071396, Validation Loss: 205.93127663930258
Epoch 4/10, Training Loss: 55.38476036919488, Validation Loss: 202.0959882736206
Epoch 5/10, Training Loss: 52.79541244506836, Validation Loss: 191.49610010782877
Epoch 6/10, Training Loss: 50.71405766805013, Validation Loss: 193.31329123179117
Epoch 7/10, Training Loss: 48.682631174723305, Validation Loss: 184.10962549845377
Epoch 8/10, Training Loss: 47.73568098280165, Validation Loss: 183.49630864461264
Epoch 9/10, Training Loss: 46.577716488308376, Validation Loss: 179.7266502380371
Epoch 10/10, Training Loss: 45.92713555230035, Validation Loss: 179.190260887146
Training complete
The results look pretty good! Let's try different hyperparameters!
hidden_sizes = [100, 200, 400]
blocks = [2, 5, 10]
input_size = 64
for hidden_size in hidden_sizes:
for block in blocks:
print(f"\nTraining for hidden_size={hidden_size}, blocks = {block}")
# Instantiate the model
conditional_inn_model = ConditionalRealNVP(input_size, hidden_size, condition_size, block)
train_loss, val_loss = train_and_validate_conditional_nvp(conditional_inn_model, train_loader, val_loader,
num_epochs=30, lr=lr, print_after=1)
# plotting the loss
plot_losses(train_loss[1:], val_loss[1:], want_log_scale=0)
plt.show()
conditions_all_labels = torch.eye(condition_size)
synthetic_data = conditional_inn_model.sample(num_samples=10, conditions=conditions_all_labels)
visualize_synthetic_data(synthetic_data,title="Synthetic digits from 0 to 9")
plt.show()
Training for hidden_size=100, blocks = 2
Epoch 1/30, Training Loss: 293.8609095255534, Validation Loss: 522.8373781840006
Epoch 2/30, Training Loss: 137.28045535617406, Validation Loss: 457.2848320007324
Epoch 3/30, Training Loss: 118.8162612915039, Validation Loss: 405.67692375183105
Epoch 4/30, Training Loss: 107.44375440809462, Validation Loss: 390.22615814208984
Epoch 5/30, Training Loss: 100.79131232367621, Validation Loss: 340.93681780497235
Epoch 6/30, Training Loss: 94.5724119398329, Validation Loss: 333.0782648722331
Epoch 7/30, Training Loss: 91.05331556532118, Validation Loss: 323.25561205546063
Epoch 8/30, Training Loss: 89.2670910305447, Validation Loss: 338.40706125895184
Epoch 9/30, Training Loss: 87.65778791639539, Validation Loss: 304.7028503417969
Epoch 10/30, Training Loss: 84.94269849989149, Validation Loss: 308.5491994222005
Epoch 11/30, Training Loss: 83.56937561035156, Validation Loss: 300.5091025034587
Epoch 12/30, Training Loss: 82.23126559787326, Validation Loss: 293.45045534769696
Epoch 13/30, Training Loss: 80.81693674723307, Validation Loss: 300.28919982910156
Epoch 14/30, Training Loss: 80.5521474202474, Validation Loss: 297.10655721028644
Epoch 15/30, Training Loss: 79.80333201090495, Validation Loss: 291.25659370422363
Epoch 16/30, Training Loss: 78.20782521565755, Validation Loss: 306.5996602376302
Epoch 17/30, Training Loss: 77.80365125868056, Validation Loss: 290.4383665720622
Epoch 18/30, Training Loss: 78.24269239637587, Validation Loss: 275.5026200612386
Epoch 19/30, Training Loss: 77.60413360595703, Validation Loss: 281.0929635365804
Epoch 20/30, Training Loss: 76.18750779893664, Validation Loss: 275.5623073577881
Epoch 21/30, Training Loss: 76.36902414957682, Validation Loss: 287.7348397572835
Epoch 22/30, Training Loss: 76.31015421549479, Validation Loss: 277.7603931427002
Epoch 23/30, Training Loss: 74.88749016655817, Validation Loss: 274.15149815877277
Epoch 24/30, Training Loss: 74.62725897894965, Validation Loss: 288.30015627543133
Epoch 25/30, Training Loss: 74.62852376302084, Validation Loss: 266.2942295074463
Epoch 26/30, Training Loss: 72.85536702473958, Validation Loss: 286.1646811167399
Epoch 27/30, Training Loss: 74.18861253526475, Validation Loss: 269.11271413167316
Epoch 28/30, Training Loss: 73.26144222683376, Validation Loss: 276.10282198588055
Epoch 29/30, Training Loss: 72.91037936740452, Validation Loss: 266.6187432607015
Epoch 30/30, Training Loss: 72.74726952446832, Validation Loss: 269.72121556599933
Training complete
Training for hidden_size=100, blocks = 5
Epoch 1/30, Training Loss: 245.61961432562933, Validation Loss: 288.1628958384196
Epoch 2/30, Training Loss: 73.73544447157118, Validation Loss: 253.28568267822266
Epoch 3/30, Training Loss: 65.67341079711915, Validation Loss: 233.32082843780518
Epoch 4/30, Training Loss: 61.25560048421224, Validation Loss: 223.8289836247762
Epoch 5/30, Training Loss: 58.051698981391056, Validation Loss: 212.6532974243164
Epoch 6/30, Training Loss: 55.78813883463542, Validation Loss: 214.10021591186523
Epoch 7/30, Training Loss: 53.83916973537869, Validation Loss: 211.61553732554117
Epoch 8/30, Training Loss: 52.40808283487956, Validation Loss: 207.70124022165933
Epoch 9/30, Training Loss: 51.4053960164388, Validation Loss: 201.84393946329752
Epoch 10/30, Training Loss: 50.25626102023654, Validation Loss: 203.0097745259603
Epoch 11/30, Training Loss: 49.2913705613878, Validation Loss: 202.8919070561727
Epoch 12/30, Training Loss: 48.77919074164497, Validation Loss: 199.91488869984946
Epoch 13/30, Training Loss: 48.111915503607854, Validation Loss: 198.21980253855386
Epoch 14/30, Training Loss: 47.437728034125435, Validation Loss: 199.25023110707602
Epoch 15/30, Training Loss: 47.18081834581163, Validation Loss: 198.07999897003174
Epoch 16/30, Training Loss: 46.162787458631726, Validation Loss: 198.0883830388387
Epoch 17/30, Training Loss: 46.0449708726671, Validation Loss: 197.23108418782553
Epoch 18/30, Training Loss: 44.870973375108505, Validation Loss: 196.50633462270102
Epoch 19/30, Training Loss: 44.54693255954319, Validation Loss: 193.11839516957602
Epoch 20/30, Training Loss: 43.91679543389215, Validation Loss: 196.35790157318115
Epoch 21/30, Training Loss: 43.8365002102322, Validation Loss: 191.82647037506104
Epoch 22/30, Training Loss: 43.50753716362847, Validation Loss: 195.8693552017212
Epoch 23/30, Training Loss: 43.52292590671115, Validation Loss: 191.5731871922811
Epoch 24/30, Training Loss: 42.970313771565756, Validation Loss: 193.97451178232828
Epoch 25/30, Training Loss: 42.77379387749566, Validation Loss: 188.9208027521769
Epoch 26/30, Training Loss: 42.4356568230523, Validation Loss: 191.19274870554605
Epoch 27/30, Training Loss: 41.90263214111328, Validation Loss: 194.66618855794272
Epoch 28/30, Training Loss: 42.02589831882053, Validation Loss: 193.22452799479166
Epoch 29/30, Training Loss: 41.702737511528866, Validation Loss: 186.7782309850057
Epoch 30/30, Training Loss: 41.530994245741105, Validation Loss: 199.4575522740682
Training complete
Training for hidden_size=100, blocks = 10
Epoch 1/30, Training Loss: 1349.3783764309353, Validation Loss: 252.0936533610026
Epoch 2/30, Training Loss: 64.5319458855523, Validation Loss: 225.27423095703125
Epoch 3/30, Training Loss: 57.80359793768989, Validation Loss: 210.5755526224772
Epoch 4/30, Training Loss: 53.993755171034074, Validation Loss: 205.24729283650717
Epoch 5/30, Training Loss: 51.40060594346788, Validation Loss: 198.3798786799113
Epoch 6/30, Training Loss: 49.362469312879774, Validation Loss: 191.2987314860026
Epoch 7/30, Training Loss: 47.77900839911567, Validation Loss: 197.155779838562
Epoch 8/30, Training Loss: 46.82213287353515, Validation Loss: 192.00089104970297
Epoch 9/30, Training Loss: 45.66339738633898, Validation Loss: 186.95856475830078
Epoch 10/30, Training Loss: 44.15901014539931, Validation Loss: 186.4817123413086
Epoch 11/30, Training Loss: 43.3345470852322, Validation Loss: 187.1396245956421
Epoch 12/30, Training Loss: 42.67904968261719, Validation Loss: 188.0238265991211
Epoch 13/30, Training Loss: 41.80822982788086, Validation Loss: 185.34127044677734
Epoch 14/30, Training Loss: 41.55522350735134, Validation Loss: 187.3965581258138
Epoch 15/30, Training Loss: 41.04959403143989, Validation Loss: 184.37189610799155
Epoch 16/30, Training Loss: 40.496132405598956, Validation Loss: 184.7194267908732
Epoch 17/30, Training Loss: 39.74103876749675, Validation Loss: 184.01869996388754
Epoch 18/30, Training Loss: 39.6324717203776, Validation Loss: 182.57812881469727
Epoch 19/30, Training Loss: 38.70297181871202, Validation Loss: 183.27468649546304
Epoch 20/30, Training Loss: 38.544365607367624, Validation Loss: 184.38559182484946
Epoch 21/30, Training Loss: 38.10673684014215, Validation Loss: 188.7575225830078
Epoch 22/30, Training Loss: 38.150281185574, Validation Loss: 182.02412446339926
Epoch 23/30, Training Loss: 37.513829718695746, Validation Loss: 184.8826446533203
Epoch 24/30, Training Loss: 36.985147603352864, Validation Loss: 187.850931485494
Epoch 25/30, Training Loss: 37.024013943142364, Validation Loss: 188.29494953155518
Epoch 26/30, Training Loss: 36.79429465399848, Validation Loss: 181.92155679066977
Epoch 27/30, Training Loss: 36.5680906507704, Validation Loss: 188.49392795562744
Epoch 28/30, Training Loss: 36.31601587931315, Validation Loss: 185.95901012420654
Epoch 29/30, Training Loss: 36.24478285047743, Validation Loss: 185.64122422536215
Epoch 30/30, Training Loss: 35.76986846923828, Validation Loss: 186.4094565709432
Training complete
Training for hidden_size=200, blocks = 2
Epoch 1/30, Training Loss: 251.62340698242187, Validation Loss: 493.00861167907715
Epoch 2/30, Training Loss: 133.07212592230903, Validation Loss: 445.5189565022786
Epoch 3/30, Training Loss: 120.69439256456164, Validation Loss: 412.8376407623291
Epoch 4/30, Training Loss: 111.98785451253255, Validation Loss: 395.0897928873698
Epoch 5/30, Training Loss: 106.65084398057726, Validation Loss: 371.44888496398926
Epoch 6/30, Training Loss: 104.55206366644965, Validation Loss: 375.08094724019367
Epoch 7/30, Training Loss: 100.62608981662326, Validation Loss: 354.244104385376
Epoch 8/30, Training Loss: 101.3826168484158, Validation Loss: 361.2765522003174
Epoch 9/30, Training Loss: 97.17907290988498, Validation Loss: 350.2586212158203
Epoch 10/30, Training Loss: 99.39001719156902, Validation Loss: 375.323211034139
Epoch 11/30, Training Loss: 96.24877268473307, Validation Loss: 342.65224011739093
Epoch 12/30, Training Loss: 95.22671322292751, Validation Loss: 371.6747303009033
Epoch 13/30, Training Loss: 94.0648691813151, Validation Loss: 335.6468200683594
Epoch 14/30, Training Loss: 94.47991282145182, Validation Loss: 333.73937034606934
Epoch 15/30, Training Loss: 92.03536478678386, Validation Loss: 328.69567171732587
Epoch 16/30, Training Loss: 92.69359588623047, Validation Loss: 339.59826405843097
Epoch 17/30, Training Loss: 90.38678927951389, Validation Loss: 330.8305486043294
Epoch 18/30, Training Loss: 90.6170661078559, Validation Loss: 326.24981753031415
Epoch 19/30, Training Loss: 90.67704806857638, Validation Loss: 326.8300698598226
Epoch 20/30, Training Loss: 90.08084496392144, Validation Loss: 345.3441562652588
Epoch 21/30, Training Loss: 90.32333001030815, Validation Loss: 332.39401563008624
Epoch 22/30, Training Loss: 89.01673482259115, Validation Loss: 329.86750348409015
Epoch 23/30, Training Loss: 89.7814205593533, Validation Loss: 328.3523635864258
Epoch 24/30, Training Loss: 88.09744957817925, Validation Loss: 319.090092976888
Epoch 25/30, Training Loss: 87.52293056911893, Validation Loss: 344.37699190775555
Epoch 26/30, Training Loss: 87.02120480007595, Validation Loss: 350.18481890360516
Epoch 27/30, Training Loss: 84.56066385904948, Validation Loss: 311.2099526723226
Epoch 28/30, Training Loss: 85.24506157769098, Validation Loss: 312.0635674794515
Epoch 29/30, Training Loss: 84.36019151475695, Validation Loss: 322.24486605326337
Epoch 30/30, Training Loss: 85.31957227918836, Validation Loss: 306.62286885579425
Training complete
Training for hidden_size=200, blocks = 5
Epoch 1/30, Training Loss: 314.2896986219618, Validation Loss: 290.5458056131999
Epoch 2/30, Training Loss: 74.96810743543837, Validation Loss: 255.79893811543783
Epoch 3/30, Training Loss: 66.97058885362414, Validation Loss: 235.95261446634927
Epoch 4/30, Training Loss: 62.063047112358944, Validation Loss: 225.87391726175943
Epoch 5/30, Training Loss: 58.89331987169054, Validation Loss: 219.84676583607992
Epoch 6/30, Training Loss: 56.11679161919488, Validation Loss: 213.3170550664266
Epoch 7/30, Training Loss: 54.45528928968641, Validation Loss: 207.7753407160441
Epoch 8/30, Training Loss: 52.892186652289496, Validation Loss: 206.37459564208984
Epoch 9/30, Training Loss: 51.78697018093533, Validation Loss: 210.35921986897787
Epoch 10/30, Training Loss: 50.40202967325846, Validation Loss: 202.17945830027261
Epoch 11/30, Training Loss: 49.897619289822046, Validation Loss: 200.61742146809897
Epoch 12/30, Training Loss: 49.18618816799588, Validation Loss: 200.7643254597982
Epoch 13/30, Training Loss: 47.86549309624566, Validation Loss: 199.72708129882812
Epoch 14/30, Training Loss: 47.19406238132053, Validation Loss: 198.47883065541586
Epoch 15/30, Training Loss: 46.79597422281901, Validation Loss: 196.3492390314738
Epoch 16/30, Training Loss: 46.761542850070526, Validation Loss: 196.71438121795654
Epoch 17/30, Training Loss: 45.04479276869032, Validation Loss: 193.24608580271402
Epoch 18/30, Training Loss: 45.82250001695421, Validation Loss: 197.4443629582723
Epoch 19/30, Training Loss: 44.56912511189778, Validation Loss: 195.75810464223227
Epoch 20/30, Training Loss: 44.55307981703017, Validation Loss: 196.81069056193033
Epoch 21/30, Training Loss: 43.780117713080514, Validation Loss: 190.78466955820718
Epoch 22/30, Training Loss: 43.49141337076823, Validation Loss: 193.90233008066812
Epoch 23/30, Training Loss: 43.6044068230523, Validation Loss: 200.35870520273843
Epoch 24/30, Training Loss: 43.11416583591037, Validation Loss: 198.29298496246338
Epoch 25/30, Training Loss: 42.296264224582245, Validation Loss: 196.31312561035156
Epoch 26/30, Training Loss: 42.06908925374349, Validation Loss: 198.18200302124023
Epoch 27/30, Training Loss: 42.08123363918728, Validation Loss: 197.08484395345053
Epoch 28/30, Training Loss: 41.53609085083008, Validation Loss: 196.21040725708008
Epoch 29/30, Training Loss: 41.187508307562936, Validation Loss: 194.32446511586508
Epoch 30/30, Training Loss: 41.22317564222548, Validation Loss: 199.1661138534546
Training complete
Training for hidden_size=200, blocks = 10
Epoch 1/30, Training Loss: 8550.333620876736, Validation Loss: 252.25765419006348
Epoch 2/30, Training Loss: 64.73046518961588, Validation Loss: 221.92402013142905
Epoch 3/30, Training Loss: 57.476726362440324, Validation Loss: 208.8302043279012
Epoch 4/30, Training Loss: 53.80047760009766, Validation Loss: 204.94884077707925
Epoch 5/30, Training Loss: 51.54618767632378, Validation Loss: 199.60729948679605
Epoch 6/30, Training Loss: 49.243514166937935, Validation Loss: 195.52471828460693
Epoch 7/30, Training Loss: 47.38635228474935, Validation Loss: 190.7220137914022
Epoch 8/30, Training Loss: 46.45990176730686, Validation Loss: 198.09132544199625
Epoch 9/30, Training Loss: 45.40678176879883, Validation Loss: 190.19877115885416
Epoch 10/30, Training Loss: 44.32002614339193, Validation Loss: 190.01830673217773
Epoch 11/30, Training Loss: 43.25411470201281, Validation Loss: 188.1542704900106
Epoch 12/30, Training Loss: 42.34379747178819, Validation Loss: 189.91027164459229
Epoch 13/30, Training Loss: 41.8846063401964, Validation Loss: 185.82088088989258
Epoch 14/30, Training Loss: 41.15804562038846, Validation Loss: 190.71087487538657
Epoch 15/30, Training Loss: 40.05203467475043, Validation Loss: 187.0631825129191
Epoch 16/30, Training Loss: 39.5969729953342, Validation Loss: 185.98451582590738
Epoch 17/30, Training Loss: 39.40454610188802, Validation Loss: 187.34200191497803
Epoch 18/30, Training Loss: 39.446832275390626, Validation Loss: 184.95723565419516
Epoch 19/30, Training Loss: 38.229959615071614, Validation Loss: 183.16800848642984
Epoch 20/30, Training Loss: 38.03151304456923, Validation Loss: 187.79945786794028
Epoch 21/30, Training Loss: 37.87414449055989, Validation Loss: 190.2610190709432
Epoch 22/30, Training Loss: 37.65578426784939, Validation Loss: 185.9701935450236
Epoch 23/30, Training Loss: 36.962117513020836, Validation Loss: 188.50027306874594
Epoch 24/30, Training Loss: 36.36226069132487, Validation Loss: 189.6880203882853
Epoch 25/30, Training Loss: 36.388100941975914, Validation Loss: 190.7065750757853
Epoch 26/30, Training Loss: 36.236502668592664, Validation Loss: 191.58643309275308
Epoch 27/30, Training Loss: 35.916646321614586, Validation Loss: 191.9307139714559
Epoch 28/30, Training Loss: 35.55673332214356, Validation Loss: 189.29961744944254
Epoch 29/30, Training Loss: 35.55209545559353, Validation Loss: 189.1108185450236
Epoch 30/30, Training Loss: 35.0932319217258, Validation Loss: 184.86688454945883
Training complete
Training for hidden_size=400, blocks = 2
Epoch 1/30, Training Loss: 309.8460605197483, Validation Loss: 641.5680033365885
Epoch 2/30, Training Loss: 167.27688564724392, Validation Loss: 550.9989280700684
Epoch 3/30, Training Loss: 155.16913079155816, Validation Loss: 552.8562787373861
Epoch 4/30, Training Loss: 142.22390001085068, Validation Loss: 491.44150416056317
Epoch 5/30, Training Loss: 137.04273817274304, Validation Loss: 793.9056180318197
Epoch 6/30, Training Loss: 158.23646189371746, Validation Loss: 612.9892018636068
Epoch 7/30, Training Loss: 138.3919923570421, Validation Loss: 442.6516710917155
Epoch 8/30, Training Loss: 133.79639689127603, Validation Loss: 836.4147109985352
Epoch 9/30, Training Loss: 146.32217983669705, Validation Loss: 468.4177182515462
Epoch 10/30, Training Loss: 137.2965772840712, Validation Loss: 465.1561864217122
Epoch 11/30, Training Loss: 142.33085208468967, Validation Loss: 486.0627187093099
Epoch 12/30, Training Loss: 129.43356441921657, Validation Loss: 419.21056874593097
Epoch 13/30, Training Loss: 122.93208906385634, Validation Loss: 428.186274210612
Epoch 14/30, Training Loss: 132.58959350585937, Validation Loss: 463.64319864908856
Epoch 15/30, Training Loss: 113.37272321912977, Validation Loss: 377.0710964202881
Epoch 16/30, Training Loss: 117.27298194037543, Validation Loss: 417.78064982096356
Epoch 17/30, Training Loss: 129.6254869249132, Validation Loss: 422.8923905690511
Epoch 18/30, Training Loss: 112.74502393934462, Validation Loss: 425.263973236084
Epoch 19/30, Training Loss: 114.95833892822266, Validation Loss: 388.4588680267334
Epoch 20/30, Training Loss: 116.57845085991754, Validation Loss: 398.36245282491046
Epoch 21/30, Training Loss: 107.6398440890842, Validation Loss: 395.6511694590251
Epoch 22/30, Training Loss: 114.73315056694878, Validation Loss: 365.79894574483234
Epoch 23/30, Training Loss: 153.03194698757596, Validation Loss: 548.2154261271158
Epoch 24/30, Training Loss: 120.47216135660807, Validation Loss: 377.80611483256024
Epoch 25/30, Training Loss: 121.93333214653863, Validation Loss: 359.2391185760498
Epoch 26/30, Training Loss: 107.88890940348307, Validation Loss: 388.6443614959717
Epoch 27/30, Training Loss: 105.48415205213759, Validation Loss: 394.25697898864746
Epoch 28/30, Training Loss: 103.53916642930773, Validation Loss: 359.95915667215985
Epoch 29/30, Training Loss: 103.76796654595269, Validation Loss: 404.914644241333
Epoch 30/30, Training Loss: 101.15233103434245, Validation Loss: 371.7300599416097
Training complete
Training for hidden_size=400, blocks = 5
Epoch 1/30, Training Loss: 283.9519036187066, Validation Loss: 306.4883219401042
Epoch 2/30, Training Loss: 78.9733864678277, Validation Loss: 267.76504071553546
Epoch 3/30, Training Loss: 71.24371761745877, Validation Loss: 249.29519907633463
Epoch 4/30, Training Loss: 66.88938734266493, Validation Loss: 242.08064715067545
Epoch 5/30, Training Loss: 64.17731535169813, Validation Loss: 234.6187343597412
Epoch 6/30, Training Loss: 61.6169075012207, Validation Loss: 232.41984430948892
Epoch 7/30, Training Loss: 61.058763631184895, Validation Loss: 227.2152500152588
Epoch 8/30, Training Loss: 58.36835199991862, Validation Loss: 221.61338710784912
Epoch 9/30, Training Loss: 57.72528330485026, Validation Loss: 223.6118532816569
Epoch 10/30, Training Loss: 56.28718982272678, Validation Loss: 223.97171783447266
Epoch 11/30, Training Loss: 56.164582061767575, Validation Loss: 217.72104358673096
Epoch 12/30, Training Loss: 55.18229760064019, Validation Loss: 218.48025862375894
Epoch 13/30, Training Loss: 54.764756774902345, Validation Loss: 226.10987981160483
Epoch 14/30, Training Loss: 54.429725138346356, Validation Loss: 217.99609343210855
Epoch 15/30, Training Loss: 52.93662999471029, Validation Loss: 219.3150078455607
Epoch 16/30, Training Loss: 52.57396011352539, Validation Loss: 219.934596379598
Epoch 17/30, Training Loss: 52.49516406589084, Validation Loss: 221.11824067433676
Epoch 18/30, Training Loss: 51.70450863308377, Validation Loss: 213.79633712768555
Epoch 19/30, Training Loss: 51.20831722683377, Validation Loss: 216.88543923695883
Epoch 20/30, Training Loss: 51.07510291205512, Validation Loss: 217.75438944498697
Epoch 21/30, Training Loss: 50.03879080878364, Validation Loss: 216.98380406697592
Epoch 22/30, Training Loss: 49.473376210530596, Validation Loss: 220.66783332824707
Epoch 23/30, Training Loss: 49.485553656684026, Validation Loss: 218.93354606628418
Epoch 24/30, Training Loss: 49.50805435180664, Validation Loss: 215.00685501098633
Epoch 25/30, Training Loss: 48.39864205254449, Validation Loss: 218.3308277130127
Epoch 26/30, Training Loss: 48.350467936197916, Validation Loss: 217.81539980570474
Epoch 27/30, Training Loss: 48.22176903618707, Validation Loss: 216.50745010375977
Epoch 28/30, Training Loss: 47.75731006198459, Validation Loss: 213.8287207285563
Epoch 29/30, Training Loss: 47.470498402913414, Validation Loss: 216.02937698364258
Epoch 30/30, Training Loss: 46.7009398566352, Validation Loss: 221.29212951660156
Training complete
Training for hidden_size=400, blocks = 10
Epoch 1/30, Training Loss: 1178.0424274020725, Validation Loss: 314.2991199493408
Epoch 2/30, Training Loss: 80.48548092312284, Validation Loss: 272.5304183959961
Epoch 3/30, Training Loss: 71.183130730523, Validation Loss: 252.77461687723795
Epoch 4/30, Training Loss: 67.16154123942057, Validation Loss: 247.56891632080078
Epoch 5/30, Training Loss: 63.57256266276042, Validation Loss: 238.0752493540446
Epoch 6/30, Training Loss: 61.470615810818146, Validation Loss: 232.28548304239908
Epoch 7/30, Training Loss: 59.5168451944987, Validation Loss: 230.8792053858439
Epoch 8/30, Training Loss: 58.22029656304254, Validation Loss: 229.1913324991862
Epoch 9/30, Training Loss: 57.4977183871799, Validation Loss: 223.06117502848306
Epoch 10/30, Training Loss: 56.045141516791446, Validation Loss: 226.03118069966635
Epoch 11/30, Training Loss: 55.24988725450304, Validation Loss: 221.1461903254191
Epoch 12/30, Training Loss: 55.002064853244356, Validation Loss: 221.07627709706625
Epoch 13/30, Training Loss: 53.88930994669597, Validation Loss: 212.35111586252847
Epoch 14/30, Training Loss: 52.65658925374349, Validation Loss: 217.0702174504598
Epoch 15/30, Training Loss: 52.393392605251734, Validation Loss: 215.39574591318765
Epoch 16/30, Training Loss: 51.5170295715332, Validation Loss: 216.01587708791098
Epoch 17/30, Training Loss: 51.10721206665039, Validation Loss: 216.7215223312378
Epoch 18/30, Training Loss: 50.49084082709418, Validation Loss: 214.5624647140503
Epoch 19/30, Training Loss: 50.000826856825086, Validation Loss: 219.19682820638022
Epoch 20/30, Training Loss: 49.75510025024414, Validation Loss: 215.7393010457357
Epoch 21/30, Training Loss: 49.33414815266927, Validation Loss: 217.59172598520914
Epoch 22/30, Training Loss: 49.12450781928168, Validation Loss: 212.6950419743856
Epoch 23/30, Training Loss: 48.87369783189562, Validation Loss: 218.82337061564127
Epoch 24/30, Training Loss: 48.36188820732964, Validation Loss: 214.2415647506714
Epoch 25/30, Training Loss: 48.280346086290145, Validation Loss: 215.3443225224813
Epoch 26/30, Training Loss: 48.694633399115666, Validation Loss: 217.40858713785806
Epoch 27/30, Training Loss: 48.18829659356011, Validation Loss: 221.2189852396647
Epoch 28/30, Training Loss: 47.713622368706595, Validation Loss: 213.31340758005777
Epoch 29/30, Training Loss: 47.00494562784831, Validation Loss: 217.6350638071696
Epoch 30/30, Training Loss: 47.01122521294488, Validation Loss: 212.82092920939127
Training complete
We still get the best validation loss for the same network structure as before (hidden_size = 100 with 10 blocks), so we are going to stick with that and the same training hyperparameters.
### conditional real NVP class
class ConditionalRealNVP_bottleneck(nn.Module):
def __init__(self, input_size, hidden_size, condition_size, blocks, k):
"""
Initialize a ConditionalRealNVP_bottleneck model.
Args:
- input_size (int): Total size of the input data.
- hidden_size (int): Size of the hidden layers in the neural networks.
- condition_size (int): Size of the condition vector (e.g., one-hot encoded label size).
- blocks (int): Number of coupling layers in the model.
- k (int): Number of leading latent dimensions treated as important; the remaining dimensions are zeroed out for the reconstruction.
"""
super(ConditionalRealNVP_bottleneck, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.condition_size = condition_size
self.blocks = blocks
self.k = k
# List of coupling layers
self.coupling_layers = nn.ModuleList([
ConditionalCouplingLayer(input_size, hidden_size, condition_size) for _ in range(blocks)
])
# List to store orthonormal matrices
self.orthonormal_matrices = [self._get_orthonormal_matrix(input_size) for _ in range(blocks)]
# List to store scaling_before_exp for each block
self.scaling_before_exp_list = []
def _get_orthonormal_matrix(self, size):
"""
Generate a random orthonormal matrix.
Args:
- size (int): Size of the matrix.
Returns:
- q (torch.Tensor): Orthonormal matrix.
"""
w = torch.randn(size, size)
q, _ = torch.linalg.qr(w, 'reduced')
return q
def forward_realnvp(self, x, condition):
"""
Forward pass through the ConditionalRealNVP model.
Args:
- x (torch.Tensor): Input data.
- condition (torch.Tensor): Condition vector.
Returns:
- x (torch.Tensor): Transformed data.
"""
scaling_before_exp_list = []
for i in range(self.blocks):
#print("x is:"); print(x)
#print("shape of x is:"); print(x.shape)
x = torch.matmul(x, self.orthonormal_matrices[i])
x, scaling_before_exp = self.coupling_layers[i].forward(x, condition)
scaling_before_exp_list.append(scaling_before_exp)
self.scaling_before_exp_list = scaling_before_exp_list
return x
def decode(self, z, condition):
# Modify z to zero out dimensions beyond k for the reconstruction
z_reconstructed = z.clone()
if self.k < self.input_size:
z_reconstructed[:, self.k:] = 0 # Zero out dimensions beyond k
# Proceed with the original decoding process
for i in reversed(range(self.blocks)):
z = self.coupling_layers[i].backward(z, condition)
z_reconstructed = self.coupling_layers[i].backward(z_reconstructed, condition)
z = torch.matmul(z, self.orthonormal_matrices[i].t())
z_reconstructed = torch.matmul(z_reconstructed, self.orthonormal_matrices[i].t())
return z, z_reconstructed
def sample(self, num_samples=1000, conditions=None):
"""
Generate synthetic samples.
Args:
- num_samples (int): Number of synthetic samples to generate.
- conditions (torch.Tensor): Conditions for generating synthetic samples.
Returns:
- synthetic_samples (torch.Tensor): Synthetic samples.
"""
with torch.no_grad():
z = torch.randn(num_samples, self.input_size)
synthetic_samples, _ = self.decode(z, conditions)
return synthetic_samples
def sample_only_important(self, num_samples=1000, conditions=None):
# Generate random samples from a standard normal distribution
with torch.no_grad():
z_1 = torch.randn(num_samples, self.k)
z_2 = torch.zeros(num_samples, self.input_size - self.k)
z = torch.cat((z_1, z_2), dim=1)
# Apply the reverse transformations (decoder) to generate synthetic samples
synthetic_samples, _ = self.decode(z, conditions)
return synthetic_samples
def sample_only_unimportant(self, num_samples=1000, conditions=None):
# Generate random samples from a standard normal distribution
with torch.no_grad():
z_1 = torch.randn(1, self.k).repeat(num_samples, 1)
z_2 = torch.randn(num_samples, self.input_size - self.k)
z = torch.cat((z_1, z_2), dim=1)
# Apply the reverse transformations (decoder) to generate synthetic samples
synthetic_samples, _ = self.decode(z, conditions)
return synthetic_samples
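For reference, a quick usage sketch of the three sampling modes (assuming a trained ConditionalRealNVP_bottleneck named model, with k as above):
conditions = torch.eye(10)  # one one-hot condition per digit class
full_samples = model.sample(num_samples=10, conditions=conditions)                      # all latent dims drawn from N(0, I)
important_only = model.sample_only_important(num_samples=10, conditions=conditions)    # dims >= k set to zero
unimportant_only = model.sample_only_unimportant(num_samples=10, conditions=conditions)  # first k dims fixed, rest random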
### training_the_conditional_nvp model
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.nn.functional import one_hot # used below to encode the labels
def train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
"""
Train the ConditionalRealNVP_bottleneck model and evaluate on a validation dataset.
Args:
- model (ConditionalRealNVP_bottleneck): The model to be trained.
- train_loader (DataLoader): DataLoader for the training dataset.
- val_loader (DataLoader): DataLoader for the validation dataset.
- num_epochs (int): Number of training epochs.
- lr (float): Learning rate for the optimizer.
- print_after (int): Number of epochs after which to print the training and validation loss.
Returns:
- train_losses_nll, train_losses_recons (lists): Training NLL and reconstruction losses per epoch.
- val_losses_nll, val_losses_recons (lists): Validation NLL and reconstruction losses per epoch.
"""
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
mse_loss = nn.MSELoss()
train_losses_nll = []
val_losses_nll = []
train_losses_recons = []
val_losses_recons = []
# Training phase
for epoch in range(num_epochs):
model.train() # Set the model to training mode (the validation phase below switches it to eval every epoch)
total_train_loss_nll = 0.0
total_train_loss_recons = 0.0
for data, labels in train_loader:
inputs = data
conditions = one_hot(labels, num_classes=model.condition_size).float()
# Zero the gradients
optimizer.zero_grad()
# Forward pass (encoding)
encoded = model.forward_realnvp(inputs, conditions)
train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
# Reconstruction loss calculation
_, decoded = model.decode(encoded, conditions)
train_loss_recons = mse_loss(inputs, decoded)
# Backward pass (gradient computation)
loss = train_loss_nll + train_loss_recons
loss.backward()
### added recently: clip the gradients
clip_grad_norm_(model.parameters(), max_norm=1.0) # Adjust max_norm as needed
# Update weights
optimizer.step()
total_train_loss_nll += train_loss_nll.item()
total_train_loss_recons += train_loss_recons.item()
# Average training loss for the epoch
average_train_loss_nll = total_train_loss_nll / len(train_loader)
average_train_loss_recons = total_train_loss_recons / len(train_loader)
# Validation phase
model.eval() # Set the model to evaluation mode
if val_loader is not None:
total_val_loss_nll = 0.0
total_val_loss_recons = 0.0
with torch.no_grad():
for val_data, val_labels in val_loader:
val_inputs = val_data
val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()
# Forward pass (encoding) for validation
val_encoded = model.forward_realnvp(val_inputs, val_conditions)
# NLL Loss calculation
val_loss_nll = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
# Reconstruction loss calculation
_, decoded = model.decode(val_encoded, val_conditions)
val_loss_recons = mse_loss(val_inputs, decoded)
total_val_loss_nll += val_loss_nll.item()
total_val_loss_recons += val_loss_recons.item()
# Average validation loss for the epoch
average_val_loss_nll = total_val_loss_nll / len(val_loader)
average_val_loss_recons = total_val_loss_recons / len(val_loader)
# Print training and validation losses together
if (epoch + 1) % print_after == 0:
print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")
# Append losses to the lists
train_losses_nll.append(average_train_loss_nll)
val_losses_nll.append(average_val_loss_nll)
train_losses_recons.append(average_train_loss_recons)
val_losses_recons.append(average_val_loss_recons)
print("Training complete")
return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
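calculate_loss is defined elsewhere in this notebook; as a hedged sketch (the exact role of the len(train_loader) argument is an assumption here), it typically computes the flow negative log-likelihood: a standard-normal prior term plus the log-determinant of the Jacobian, which for the coupling layers above is just the sum of scaling_before_exp (the orthonormal rotations contribute log|det| = 0):
def calculate_loss_sketch(z, scaling_before_exp_list, normalizer=1):
    # log N(z; 0, I) up to an additive constant
    log_prior = -0.5 * (z ** 2).sum(dim=1)
    # log|det J| accumulated over all coupling layers
    log_det = sum(s.sum(dim=1) for s in scaling_before_exp_list)
    # negative log-likelihood, averaged over the batch (rescaling is an assumption)
    return -(log_prior + log_det).mean() / normalizer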
k_values = [2,4,8]
dataset_percentage = 1.0
for k in k_values:
print(f"\nTraining for k={k}")
# Instantiate the model
model = ConditionalRealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10,condition_size=10,k=k)
# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()
### plot the synthetic data
conditions_all_labels = torch.eye(condition_size)
synthetic_data = model.sample_only_important(num_samples=10, conditions=conditions_all_labels)
visualize_synthetic_data(synthetic_data, title="Sampling only important features")
conditions_first_elements = torch.zeros((10, 10))
conditions_first_elements[:,0] = 1
synthetic_data=model.sample_only_unimportant(num_samples=10, conditions=conditions_first_elements)
visualize_synthetic_data(synthetic_data, title="Sampling only unimportant features")
Training for k=2
Epoch 1/20, Training Loss: 4268.482129012214, Validation Loss: 277.0808455944061
Epoch 2/20, Training Loss: 79.93924753401015, Validation Loss: 238.11668674151102
Epoch 3/20, Training Loss: 70.49177271525065, Validation Loss: 220.5944479306539
Epoch 4/20, Training Loss: 65.97008811102974, Validation Loss: 212.06634493668875
Epoch 5/20, Training Loss: 62.57478442721897, Validation Loss: 206.0178082784017
Epoch 6/20, Training Loss: 60.040938536326095, Validation Loss: 206.48840761184692
Epoch 7/20, Training Loss: 58.67695316738553, Validation Loss: 199.03116782506308
Epoch 8/20, Training Loss: 56.43557636472914, Validation Loss: 194.44701397418976
Epoch 9/20, Training Loss: 55.055921967824304, Validation Loss: 194.8996948401133
Epoch 10/20, Training Loss: 53.81541895336575, Validation Loss: 192.07523147265118
Epoch 11/20, Training Loss: 52.73026489681668, Validation Loss: 193.12475633621216
Epoch 12/20, Training Loss: 52.082277584075925, Validation Loss: 188.29537717501324
Epoch 13/20, Training Loss: 51.161164898342555, Validation Loss: 191.1772176027298
Epoch 14/20, Training Loss: 50.40476257536147, Validation Loss: 187.4039184252421
Epoch 15/20, Training Loss: 49.79507061640422, Validation Loss: 189.2285165389379
Epoch 16/20, Training Loss: 48.9830634329054, Validation Loss: 187.52914067109427
Epoch 17/20, Training Loss: 48.31131310992771, Validation Loss: 189.55784090360004
Epoch 18/20, Training Loss: 47.838651455773245, Validation Loss: 183.39785277843475
Epoch 19/20, Training Loss: 47.57893569734362, Validation Loss: 185.94831204414368
Epoch 20/20, Training Loss: 46.873543463812936, Validation Loss: 185.71018425623575
Training complete
Training for k=4
Epoch 1/20, Training Loss: 1531363.1183485244, Validation Loss: 1611.2183237075806
Epoch 2/20, Training Loss: 145.63676013946534, Validation Loss: 285.2107837994894
Epoch 3/20, Training Loss: 81.61559757656522, Validation Loss: 245.5362807114919
Epoch 4/20, Training Loss: 71.20380795796711, Validation Loss: 225.57799275716147
Epoch 5/20, Training Loss: 65.88819755978054, Validation Loss: 213.04468441009521
Epoch 6/20, Training Loss: 61.858477687835695, Validation Loss: 206.82857819398242
Epoch 7/20, Training Loss: 59.467446655697294, Validation Loss: 203.04642629623413
Epoch 8/20, Training Loss: 56.79460786183675, Validation Loss: 198.47188929716745
Epoch 9/20, Training Loss: 55.198437235090466, Validation Loss: 194.5198189020157
Epoch 10/20, Training Loss: 53.73772185643514, Validation Loss: 194.05041245619455
Epoch 11/20, Training Loss: 52.30835009680854, Validation Loss: 189.8688794374466
Epoch 12/20, Training Loss: 51.18810695012411, Validation Loss: 188.23876953125
Epoch 13/20, Training Loss: 50.2093313852946, Validation Loss: 185.54253792762756
Epoch 14/20, Training Loss: 49.52944510777791, Validation Loss: 186.65644093354544
Epoch 15/20, Training Loss: 48.634717538621686, Validation Loss: 184.9084700345993
Epoch 16/20, Training Loss: 48.00103391011556, Validation Loss: 186.0170479218165
Epoch 17/20, Training Loss: 47.37019804848565, Validation Loss: 185.29169126351675
Epoch 18/20, Training Loss: 46.76975990931193, Validation Loss: 182.8655904928843
Epoch 19/20, Training Loss: 45.86354028913709, Validation Loss: 186.44027853012085
Epoch 20/20, Training Loss: 46.205892425113255, Validation Loss: 184.22301808993024
Training complete
Training for k=8
Epoch 1/20, Training Loss: 12503657.154557291, Validation Loss: 57410.58361816406
Epoch 2/20, Training Loss: 2766.1257029215494, Validation Loss: 341.8003800710042
Epoch 3/20, Training Loss: 92.67972922854953, Validation Loss: 261.6807294686635
Epoch 4/20, Training Loss: 75.06449168523153, Validation Loss: 230.45265928904217
Epoch 5/20, Training Loss: 67.8765141805013, Validation Loss: 218.10639572143555
Epoch 6/20, Training Loss: 63.224973074595134, Validation Loss: 211.41110841433206
Epoch 7/20, Training Loss: 60.210762935214575, Validation Loss: 208.3572313785553
Epoch 8/20, Training Loss: 58.03645001517402, Validation Loss: 199.64385946591696
Epoch 9/20, Training Loss: 56.108029672834604, Validation Loss: 195.0037250916163
Epoch 10/20, Training Loss: 54.359071456061464, Validation Loss: 193.970672527949
Epoch 11/20, Training Loss: 53.008008819156224, Validation Loss: 192.40503108501434
Epoch 12/20, Training Loss: 51.81032321718004, Validation Loss: 190.76259569327038
Epoch 13/20, Training Loss: 50.47278265423245, Validation Loss: 186.2374519109726
Epoch 14/20, Training Loss: 49.65340761608548, Validation Loss: 189.4589294989904
Epoch 15/20, Training Loss: 49.08715744548374, Validation Loss: 185.34620114167532
Epoch 16/20, Training Loss: 48.26118794547187, Validation Loss: 183.30718386173248
Epoch 17/20, Training Loss: 47.3154987970988, Validation Loss: 183.28789893786114
Epoch 18/20, Training Loss: 46.86230368084377, Validation Loss: 185.00911617279053
Epoch 19/20, Training Loss: 46.01728831926982, Validation Loss: 182.14698894818622
Epoch 20/20, Training Loss: 46.02813222673204, Validation Loss: 183.1218525568644
Training complete
# Instantiate the model
model = ConditionalRealNVP_bottleneck(input_size=64, hidden_size=100, blocks=10,condition_size=10,k=2)
# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp_bottleneck(model, train_loader, val_loader, num_epochs=20, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
# plotting the loss
plot_losses(train_losses[1:], val_losses[1:], want_log_scale=0)
plt.show()
Epoch 1/20, Training Loss: 21580.730417421128, Validation Loss: 297.31214563051856
Epoch 2/20, Training Loss: 83.47490584055582, Validation Loss: 243.01717726389566
Epoch 3/20, Training Loss: 71.68545961380005, Validation Loss: 224.35330470403036
Epoch 4/20, Training Loss: 65.9989339404636, Validation Loss: 214.9265724023183
Epoch 5/20, Training Loss: 62.00650640063815, Validation Loss: 203.16986227035522
Epoch 6/20, Training Loss: 59.44165139728122, Validation Loss: 200.70353790124258
Epoch 7/20, Training Loss: 57.68972770902845, Validation Loss: 196.80060239632925
Epoch 8/20, Training Loss: 55.89540031221178, Validation Loss: 196.14737010002136
Epoch 9/20, Training Loss: 54.81338379118178, Validation Loss: 193.17052300771078
Epoch 10/20, Training Loss: 53.11203460693359, Validation Loss: 190.24123994509378
Epoch 11/20, Training Loss: 52.03209048377143, Validation Loss: 187.1682772239049
Epoch 12/20, Training Loss: 51.64189916186862, Validation Loss: 189.73837987581888
Epoch 13/20, Training Loss: 50.37107356389364, Validation Loss: 184.49492673079175
Epoch 14/20, Training Loss: 49.83295380274455, Validation Loss: 184.8419489065806
Epoch 15/20, Training Loss: 49.164054001702205, Validation Loss: 184.1409958998362
Epoch 16/20, Training Loss: 49.043845907847086, Validation Loss: 185.45416287581125
Epoch 17/20, Training Loss: 48.46328016916911, Validation Loss: 184.55666601657867
Epoch 18/20, Training Loss: 47.334013080596925, Validation Loss: 182.43773651123047
Epoch 19/20, Training Loss: 47.01087157991197, Validation Loss: 185.98060782750449
Epoch 20/20, Training Loss: 46.53891531626384, Validation Loss: 187.26162362098694
Training complete
Here the results for k = 2 look the best: the generated images actually resemble the digits we tried to generate. Let's test the sample quality with a Random Forest classifier.
from sklearn.ensemble import RandomForestClassifier
rf_classifier = RandomForestClassifier(n_estimators=100)
rf_classifier.fit(X_train, y_train)
RandomForestClassifier()
### plot the synthetic data
from sklearn.metrics import accuracy_score
ground_truth = torch.eye(condition_size).repeat(100,1)
ground_truth_labels = np.argmax(ground_truth, axis=1)
synthetic_data = model.sample(num_samples=1000, conditions=ground_truth)
print(synthetic_data.size())
predictions = rf_classifier.predict(synthetic_data)
accuracy = accuracy_score(ground_truth_labels, predictions)
print(f'Accuracy = {accuracy}')
torch.Size([1000, 64]) Accuracy = 0.868
We see the Random Forest classifier achieves a reasonable accuracy of 0.87 on the synthetic digits!
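To see which digits the sampler struggles with, the accuracy can be broken down per class (a small sketch reusing the variables above):
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(ground_truth_labels, predictions)
print(cm.diagonal() / cm.sum(axis=1))  # per-class accuracy for digits 0-9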
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image, ImageFilter # install 'pillow' to get PIL
import matplotlib.pyplot as plt
# define a functor to downsample images
class DownsampleTransform:
def __init__(self, target_shape, algorithm=Image.Resampling.LANCZOS):
self.width, self.height = target_shape
self.algorithm = algorithm
def __call__(self, img):
# resize with a 1-pixel border, then crop it away to reduce edge artifacts
img = img.resize((self.width + 2, self.height + 2), self.algorithm)
img = img.crop((1, 1, self.width + 1, self.height + 1))
return img
# concatenate a few transforms
transform = transforms.Compose([
DownsampleTransform(target_shape=(8, 8)),
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor()
])
# download MNIST
mnist_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# create a DataLoader that serves minibatches of size 100
data_loader = DataLoader(mnist_dataset, batch_size=100, shuffle=True)
mnist_test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(mnist_test_dataset, batch_size=100, shuffle=True)
# visualize the first batch of downsampled MNIST images
def show_first_batch(data_loader):
for batch in data_loader:
x, y = batch
fig = plt.figure(figsize=(10, 10))
for i, img in enumerate(x):
ax = fig.add_subplot(10, 10, i + 1)
ax.imshow(img.reshape(8, 8), cmap='gray')
ax.axis('off')
break
show_first_batch(data_loader)
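Before training on it, a quick sanity check (a sketch) that the loader really serves 8x8 grayscale batches:
x, y = next(iter(data_loader))
print(x.shape)  # expected: torch.Size([100, 1, 8, 8])
print(y[:10])   # a few of the corresponding labels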
### training_the_conditional_nvp model
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
from torch.nn.functional import one_hot # used below to encode the labels
def train_and_validate_conditional_nvp(model, train_loader, val_loader, num_epochs=10, lr=0.001, print_after=1):
"""
Train the ConditionalRealNVP model and evaluate on a validation dataset.
Args:
- model (ConditionalRealNVP): The ConditionalRealNVP model to be trained.
- train_loader (DataLoader): DataLoader for the training dataset.
- val_loader (DataLoader): DataLoader for the validation dataset.
- num_epochs (int): Number of training epochs.
- lr (float): Learning rate for the optimizer.
- print_after (int): Number of epochs after which to print the training and validation loss.
Returns:
- train_losses_nll, train_losses_recons (lists): Training NLL and reconstruction losses per epoch.
- val_losses_nll, val_losses_recons (lists): Validation NLL and reconstruction losses per epoch.
"""
# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
mse_loss = nn.MSELoss()
train_losses_nll = []
val_losses_nll = []
train_losses_recons = []
val_losses_recons = []
# Training phase
for epoch in range(num_epochs):
model.train() # Set the model to training mode (the validation phase below switches it to eval every epoch)
total_train_loss_nll = 0.0
total_train_loss_recons = 0.0
for data, labels in train_loader:
inputs = data.reshape(len(labels),64)
conditions = one_hot(labels, num_classes=model.condition_size).float()
# Zero the gradients
optimizer.zero_grad()
# Forward pass (encoding)
encoded = model.forward_realnvp(inputs, conditions)
train_loss_nll = calculate_loss(encoded, model.scaling_before_exp_list, len(train_loader))
# Reconstruction loss calculation
decoded = model.decode(encoded, conditions)
train_loss_recons = mse_loss(inputs, decoded)
# Backward pass (gradient computation)
loss = train_loss_nll + train_loss_recons
loss.backward()
### added recently: clip the gradients
clip_grad_norm_(model.parameters(), max_norm=1.0) # Adjust max_norm as needed
# Update weights
optimizer.step()
total_train_loss_nll += train_loss_nll.item()
total_train_loss_recons += train_loss_recons.item()
# Average training loss for the epoch
average_train_loss_nll = total_train_loss_nll / len(train_loader)
average_train_loss_recons = total_train_loss_recons / len(train_loader)
# Validation phase
model.eval() # Set the model to evaluation mode
if val_loader is not None:
total_val_loss_nll = 0.0
total_val_loss_recons = 0.0
with torch.no_grad():
for val_data, val_labels in val_loader:
val_inputs = val_data.reshape(len(val_labels), 64) # use the current validation batch size, not the last training batch
val_conditions = one_hot(val_labels, num_classes=model.condition_size).float()
# Forward pass (encoding) for validation
val_encoded = model.forward_realnvp(val_inputs, val_conditions)
# NLL Loss calculation
val_loss_nll = calculate_loss(val_encoded, model.scaling_before_exp_list, len(val_loader))
# Reconstruction loss calculation
decoded = model.decode(val_encoded, val_conditions)
val_loss_recons = mse_loss(val_inputs, decoded)
total_val_loss_nll += val_loss_nll.item()
total_val_loss_recons += val_loss_recons.item()
# Average validation loss for the epoch
average_val_loss_nll = total_val_loss_nll / len(val_loader)
average_val_loss_recons = total_val_loss_recons / len(val_loader)
# Print training and validation losses together
if (epoch + 1) % print_after == 0:
print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_train_loss_recons+average_train_loss_nll}, Validation Loss: {average_val_loss_nll+average_val_loss_recons}")
# Append losses to the lists
train_losses_nll.append(average_train_loss_nll)
val_losses_nll.append(average_val_loss_nll)
train_losses_recons.append(average_train_loss_recons)
val_losses_recons.append(average_val_loss_recons)
print("Training complete")
return train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons
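As a design note: instead of reshaping the images inside the training loop, the flattening could also live in the transform pipeline (a sketch, not what is used below):
flat_transform = transforms.Compose([
    DownsampleTransform(target_shape=(8, 8)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: t.view(-1)),  # (1, 8, 8) -> (64,), so batches arrive already flat
])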
input_size = 64
hidden_size = 100
blocks = 10
print_after=1
dataset_percentage = 1.0
batch_size=100
# Instantiate the model
model = ConditionalRealNVP(input_size=64, hidden_size=100, blocks=10,condition_size=10)
# Train the model
train_losses_nll, train_losses_recons, val_losses_nll, val_losses_recons = train_and_validate_conditional_nvp(model, data_loader, val_loader, num_epochs=10, lr=0.005, print_after=1)
train_losses = np.array(train_losses_nll) + np.array(train_losses_recons)
val_losses = np.array(val_losses_nll) + np.array(val_losses_recons)
# plotting the loss
plot_losses(train_losses, val_losses, want_log_scale=0)
plt.show()
### plot the synthetic data
conditions_all_labels = torch.eye(condition_size)
synthetic_data = model.sample(num_samples=10, conditions=conditions_all_labels)
visualize_synthetic_data(synthetic_data,title="Synthetic digits from 0 to 9")
plt.show()
Epoch 1/10, Training Loss: -19.08519719739742, Validation Loss: -130.75578613281235
Epoch 2/10, Training Loss: -22.443291050592926, Validation Loss: -137.0801293945311
Epoch 3/10, Training Loss: -23.36428718566881, Validation Loss: -141.2957122802733
Epoch 4/10, Training Loss: -23.87661533991482, Validation Loss: -144.08587219238268
Epoch 5/10, Training Loss: -24.201026268005233, Validation Loss: -145.4849877929686
Epoch 6/10, Training Loss: -24.45872751235948, Validation Loss: -147.0356079101561
Epoch 7/10, Training Loss: -24.656676244735575, Validation Loss: -147.69472976684557
Epoch 8/10, Training Loss: -24.829769868850565, Validation Loss: -149.69698562622057
Epoch 9/10, Training Loss: -24.958654168446717, Validation Loss: -150.0165017700194
Epoch 10/10, Training Loss: -25.043524036407327, Validation Loss: -149.860236968994
Training complete
Using the MNIST dataset this way yields far worse results than the bottleneck approach, but this might be because longer training is needed.
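If longer training is indeed the bottleneck, the cheapest follow-up experiment (a sketch; the epoch count and learning rate are assumptions) is to keep training the same model — the weights carry over even though the function recreates its optimizer:
extra_losses = train_and_validate_conditional_nvp(model, data_loader, val_loader,
                                                  num_epochs=50, lr=0.001, print_after=5)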